[CK_TILE] Add logits soft-capping & customization support to the FMHA forward kernel/pipelines (#2163)

* hack for cap logits

* fix bug

* Re-format files

* Allow specifying logits_soft_cap through APIs

* Support turn on/off logits_soft_cap in async pipeline

* Do not generate non-verified kernels

* Align receipt used in Aiter

* Sync logits soft-capping across pipelines

* Re-enable some hdim pipelines

* fix perf

* Add attention variant for logits_soft_cap

* Add newline at end-of-file

* Fix performance

* Add comment to explain logits_soft_cap pre-processing

* Unify code

* Unify floating-point literal style

* Use class data member to slience the compilation error

* [CK_TILE] Update attention customizaton interface: add LogitsMask() (#2133)

* Send 'mask' along with variant params to the LogitsMask()

* Send block indices to the variant

* Add indices parameters in variant interface

* Fix fmha bwd codegen error

* Allow switch logits_soft_cap impl

* Eliminate register spills

* Fix compilation errors

* Fix wrong LSE

* Fix LSE for splitkv kernel

* Sync splitkv pipeline changes

* Add batch_prefill kernel/pipeline

* Fix codegen error

* Undo changes in CMakeLists.txt

* Merge pipeline filtering check

* Use different code path if kHasLogitsSoftCap=false

* Remove [[maybe_unused]] attribute

* Use pre-existing compile-time flag to instantiate templates

* Sync pipeline changes

* Update CHANGELOG.md

---------

Co-authored-by: Bernard <bernaliu@amd.com>
Co-authored-by: coderfeli <coderfeli@163.com>

[ROCm/composable_kernel commit: 2920604786]
This commit is contained in:
Po Yen Chen
2025-05-13 12:19:25 +08:00
committed by GitHub
parent 699091846a
commit c3f74a83b3
29 changed files with 4621 additions and 226 deletions

View File

@@ -16,6 +16,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added GEMM pipeline for microscaling (MX) data types
* Added support for FP16 2:4 structured sparsity to universal GEMM.
* Added support for Split K for grouped convolution backward data.
* Added logit soft-capping support for fMHA forward kernels.
### Optimized

View File

@@ -114,12 +114,14 @@ LAYOUT_MAP = {
PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
"qs" : "ck_tile::BlockFmhaPipelineQSKSVS",
}
PIPELINE_ENUM_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
}
BOOL_MAP = {

View File

@@ -0,0 +1,595 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
DTYPE_BITS = {
"fp32": 32,
"fp16": 16,
"bf16": 16,
"fp8" : 8,
"bf8" : 8
}
K0_MAX_SUBMAX_MAP = {
32 : 32,
64 : 64,
96 : 128,
128: 128,
256: 256
}
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
"qr_async" : "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
}
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd.hpp"
"""
FMHA_FWD_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_logits},
{F_bias},
false,
{F_lse},
{F_dropout},
{F_squant},
{F_occupancy}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_variant_{F_idx},
fmha_mask_{F_idx},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = {F_pipeline}<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
#include <iostream>
template<>
float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_batch_prefill_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp"
FMHA_FWD_API="""
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
}}
"""
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_batch_prefill_<trait_>(s, a);
}}
"""
@dataclass
class FmhaFwdApiTrait:
pipeline_tag : str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along qk seqlen
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0max : int
vlayout : str
logits : str
mask : str
bias : str #
lse : str #
dropout : str
squant : str #
spad : str
skpad : str
dpad : str
dvpad : str
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
@property
def scheck(self) -> str:
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr']:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
@property
def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr', 'qr_fp8']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k % {self.bn0} == 0'
else: assert False
@property
def dcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {bk0submax} == 0'
else: assert False
@property
def dvcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {bk0submax} == 0'
else: assert False
@dataclass
class FmhaFwdPipeline:
tag : str
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_logits : str # t/f
F_bias : str # true/false
F_lse : str #
F_dropout : str #
F_squant : str #
F_mask : str # value from MASK_MAP
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_skpad == 't' : n += 'sk'
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
else: n += '_npad'
if self.F_logits == 't' : n += '_logits'
else: n += '_nlogits'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
else: n += '_nbias'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else: n += '_nmask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
else: n += '_nmask'
if self.F_lse == 't' : n += '_lse'
else: n += '_nlse'
if self.F_dropout == 't' : n += '_dropout'
else: n += '_ndropout'
if self.F_squant == 't' : n += '_squant'
else: n += '_nsquant'
return n
class FmhaFwdApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
if trait.hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][trait.hdim] = list()
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes=str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][hdim]
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
@dataclass
class FmhaFwdTileSize:
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_bk0 : int # tile size along qk gemm unroll
F_bn1 : int # tile size along v head_dim
F_bk1 : int # tile size along kv gemm unroll
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
F_wm0 : int # gemm0 warp size along m
F_wn0 : int # gemm0 warp size along n
F_wk0 : int # gemm0 warp size along k
F_wm1 : int # gemm1 warp size along m
F_wn1 : int # gemm1 warp size along n
F_wk1 : int # gemm1 warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
@dataclass
class FmhaFwdKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
mask_impl : str
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1,
F_bk0max = self.F_tile.F_bk0max,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_wm0 = self.F_tile.F_wm0,
F_wn0 = self.F_tile.F_wn0,
F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag])
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
@property
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaFwdApiTrait:
return FmhaFwdApiTrait(
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
logits=self.F_pipeline.F_logits,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
dropout=self.F_pipeline.F_dropout,
squant=self.F_pipeline.F_squant,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
### '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
### '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
}
else:
return None
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
if hdim == 256:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
# the below two is used for hdim vectorize load
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
else:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
else:
assert False
return pipelines
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
if d == None:
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
if hdim == 192 and tile.F_bn1 == 128:
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
continue
# logits_soft_cap is only allowed if no bias
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
continue
k = FmhaFwdKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'bias']
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'batch'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# Aiter(mha_batch_prefill) integration
elif receipt == 200:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'group'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# aiter::mha_batch_prefill C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'group'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
with file_path.open('a') as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")

View File

@@ -60,6 +60,7 @@ using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
false,
{F_bias},
{F_dbias},
false,

View File

@@ -32,6 +32,7 @@ K0_MAX_SUBMAX_MAP = {
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd.hpp"
"""
@@ -51,12 +52,16 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_logits},
{F_bias},
false,
{F_lse},
{F_dropout},
{F_squant},
{F_occupancy}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
@@ -73,6 +78,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_variant_{F_idx},
fmha_mask_{F_idx},
fmha_trait_{F_idx}>;
@@ -88,7 +94,7 @@ using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
#include <iostream>
@@ -123,9 +129,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
}}
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a);
}}
"""
@@ -144,6 +150,7 @@ class FmhaFwdApiTrait:
bk1 : int # tile size along kv gemm unroll
bk0max : int
vlayout : str
logits : str
mask : str
bias : str #
lse : str #
@@ -157,7 +164,7 @@ class FmhaFwdApiTrait:
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
@property
def scheck(self) -> str:
@@ -165,7 +172,7 @@ class FmhaFwdApiTrait:
if self.pipeline_tag == 'qr_async':
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr']:
elif self.pipeline_tag in ['qr', 'qs']:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
@@ -176,7 +183,7 @@ class FmhaFwdApiTrait:
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr', 'qr_fp8']:
elif self.pipeline_tag in ['qr', 'qs']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k % {self.bn0} == 0'
else: assert False
@@ -187,7 +194,7 @@ class FmhaFwdApiTrait:
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
elif self.pipeline_tag in ['qr', 'qs']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {bk0submax} == 0'
@@ -199,7 +206,7 @@ class FmhaFwdApiTrait:
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
elif self.pipeline_tag in ['qr', 'qs']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {bk0submax} == 0'
@@ -214,6 +221,7 @@ class FmhaFwdPipeline:
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_logits : str # t/f
F_bias : str # true/false
F_lse : str #
F_dropout : str #
@@ -235,6 +243,9 @@ class FmhaFwdPipeline:
if pn != '' : n += f'_{pn}'
else: n += '_npad'
if self.F_logits == 't' : n += '_logits'
else: n += '_nlogits'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
else: n += '_nbias'
@@ -280,7 +291,7 @@ class FmhaFwdApiPool:
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
@@ -365,6 +376,7 @@ class FmhaFwdKernel:
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
@@ -399,6 +411,7 @@ class FmhaFwdKernel:
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
logits=self.F_pipeline.F_logits,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
dropout=self.F_pipeline.F_dropout,
@@ -440,36 +453,36 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
if hdim == 256:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
# the below two is used for hdim vectorize load
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
else:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
@@ -497,6 +510,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
continue
# logits_soft_cap is only allowed if no bias
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
continue
k = FmhaFwdKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,

View File

@@ -45,6 +45,7 @@ FMHA_FWD_SPLITKV_PIPELINE_MAP = {
FMHA_FWD_SPLITKV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
using fmha_mask_{F_idx} = {F_mask};
namespace {{
@@ -63,6 +64,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_logits},
{F_bias},
/*kHasBiasGrad=*/false,
{F_lse},
@@ -85,6 +87,7 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
fmha_shape,
{F_mode},
fmha_variant_{F_idx},
fmha_mask_{F_idx},
fmha_trait>;
@@ -113,7 +116,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
}}
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_dvpad}>;
#include <iostream>
@@ -267,9 +270,9 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
}}
"""
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
// get combine kernel tile sizes
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
@@ -310,6 +313,7 @@ class FmhaFwdSplitKVApiTrait:
bk0max : int
vlayout : str
mask : str
logits : str
bias : str #
lse : str #
squant : str #
@@ -322,7 +326,7 @@ class FmhaFwdSplitKVApiTrait:
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
f'{self.dvpad}-{self.pagedkv}'
@property
@@ -380,6 +384,7 @@ class FmhaFwdSplitKVPipeline:
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_logits : str # t/f
F_bias : str # true/false
F_lse : str #
F_squant : str #
@@ -401,6 +406,9 @@ class FmhaFwdSplitKVPipeline:
if pn != '' : n += f'_{pn}'
else: n += '_npad'
if self.F_logits == 't' : n += '_logits'
else: n += '_nlogits'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
else: n += '_nbias'
@@ -475,7 +483,7 @@ class FmhaFwdSplitKVApiPool:
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
@@ -541,6 +549,7 @@ class FmhaFwdSplitKVKernel:
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
@@ -574,6 +583,7 @@ class FmhaFwdSplitKVKernel:
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
logits=self.F_pipeline.F_logits,
mask=self.F_pipeline.F_mask,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
@@ -671,32 +681,32 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
# TODO: use async pipeline when compiler is more stable
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
# if True:
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
else:
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
if receipt == 1:
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask))
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 't', squant, 'f', mask))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
@@ -720,6 +730,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
# logits_soft_cap is only allowed if no bias
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
continue
k = Kernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,

View File

@@ -11,6 +11,7 @@
#include <array>
#include <cstring>
#include <functional>
#include <cmath>
#include <numeric>
#include <ostream>
#include <string>
@@ -72,6 +73,7 @@ auto create_args(int argc, char* argv[])
"0",
"scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
"note when squant=1, this value will be modified by range_q/k")
.insert("logits_soft_cap", "0", "attention logits soft capping value.")
.insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.")
.insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.")
.insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.")
@@ -416,6 +418,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(scale_s == .0f)
scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ?
const float logits_soft_cap = arg_parser.get_float("logits_soft_cap");
std::string squant_str = arg_parser.get_str("squant");
bool squant = [&]() {
if(squant_str == "auto")
@@ -850,6 +854,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
else // fmha_fwd_traits or fmha_splitkv_traits
{
traits.is_group_mode = (mode == mode_enum::group);
traits.has_logits_soft_cap = 0.f < logits_soft_cap;
traits.mask_type = mask.type;
traits.bias_type = bias.type;
traits.has_lse = lse;
@@ -1007,6 +1012,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.scale_p = scale_p;
args.scale_o = scale_o;
args.logits_soft_cap = logits_soft_cap;
args.stride_bias =
(bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias);
args.stride_o = stride_o;
@@ -1375,6 +1382,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::identity{},
ck_tile::scales(scale_s));
if(0.f < logits_soft_cap)
{
ck_tile::reference_unary_elementwise<SaccDataType, SaccDataType, SaccDataType>(
s_host_ref, s_host_ref, [logits_soft_cap](SaccDataType logits) {
return ck_tile::type_convert<SaccDataType>(
logits_soft_cap *
std::tanhf(ck_tile::type_convert<float>(logits / logits_soft_cap)));
});
}
if(bias.type == bias_enum::elementwise_bias)
{
// elementwise bias

View File

@@ -143,6 +143,8 @@ struct fmha_fwd_args
float scale_p;
float scale_o;
float logits_soft_cap;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
@@ -232,6 +234,8 @@ struct fmha_fwd_splitkv_args
float scale_p;
float scale_o;
float logits_soft_cap;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
@@ -308,6 +312,85 @@ struct fmha_fwd_appendkv_args
ck_tile::index_t batch_stride_vnew;
};
struct fmha_batch_prefill_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
void* rand_val_ptr;
void* lse_ptr;
void* o_ptr;
// the real seqlen_q & seqlen_k are decided by following:
// batch mode (kvcache):
// seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.page_block_size * (kargs.kv_indptr[b + 1] - kargs.kv_indptr[b] -
// 1) +
// kargs.kv_last_page_lens[b]
// group mode (kvcache):
// seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
// seqlen_k = kargs.page_block_size * (kargs.kv_indptr[b + 1] - kargs.kv_indptr[b] -
// 1) +
// kargs.kv_last_page_lens[b]
const void* seqstart_q_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
// SGLang-style page table
int32_t num_total_pages;
void* kv_indptr;
void* kv_page_indices;
#if 0 // we assume page_block_size=1 for now
void* kv_last_page_lens;
ck_tile::index_t page_block_size;
#endif
float scale_s;
float scale_p;
float scale_o;
float logits_soft_cap;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_randval;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
float p_drop;
bool s_randval;
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;
};
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
{
@@ -333,6 +416,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.scale_s,
args.scale_p,
args.scale_o,
args.logits_soft_cap,
args.stride_q,
args.stride_k,
args.stride_v,
@@ -371,6 +455,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.scale_s,
args.scale_p,
args.scale_o,
args.logits_soft_cap,
args.stride_q,
args.stride_k,
args.stride_v,
@@ -443,6 +528,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.is_gappy,
args.scale_s,
args.scale_p,
args.logits_soft_cap,
args.stride_q,
args.stride_k,
args.stride_v,
@@ -485,6 +571,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.cache_batch_idx,
args.scale_s,
args.scale_p,
args.logits_soft_cap,
args.stride_q,
args.stride_k,
args.stride_v,
@@ -618,6 +705,117 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
return ck_tile::make_tuple(kargs, grids);
}
template <typename FmhaKernel>
auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaKernel::kIsGroupMode)
{
return FmhaKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_ptr,
args.o_ptr,
args.seqstart_q_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_total_pages,
args.kv_indptr,
args.kv_page_indices,
#if 0 // we assume page_block_size=1 for now
args.kv_last_page_lens,
args.page_block_size,
#endif
args.scale_s,
args.scale_p,
args.scale_o,
args.logits_soft_cap,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_k,
args.batch_stride_v,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
else
{ // create batch mode kernel arguments
return FmhaKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_ptr,
args.o_ptr,
args.seqlen_q,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_total_pages,
args.kv_indptr,
args.kv_page_indices,
#if 0 // we assume page_block_size=1 for now
args.kv_last_page_lens,
args.page_block_size,
#endif
args.scale_s,
args.scale_p,
args.scale_o,
args.logits_soft_cap,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_lse,
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
}();
dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
typename DataType_,
@@ -630,6 +828,7 @@ template <ck_tile::index_t HDim_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
bool kHasLogitsSoftCap_,
typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_,
@@ -652,6 +851,7 @@ struct fmha_fwd_traits_
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_;
@@ -677,6 +877,7 @@ template <ck_tile::index_t HDim_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
bool kHasLogitsSoftCap_,
typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_,
@@ -699,6 +900,7 @@ struct fmha_fwd_splitkv_traits_
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_;
@@ -776,6 +978,9 @@ struct fmha_fwd_appendkv_traits_
template <typename Traits_>
float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args);
template <typename Traits_>
float fmha_batch_prefill_(const ck_tile::stream_config&, fmha_batch_prefill_args);
// This is the public API, will be generated by script
struct fmha_fwd_traits
{
@@ -784,6 +989,7 @@ struct fmha_fwd_traits
std::string data_type;
bool is_group_mode;
bool is_v_rowmajor;
bool has_logits_soft_cap;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
@@ -800,6 +1006,7 @@ struct fmha_fwd_splitkv_traits
std::string data_type;
bool is_group_mode;
bool is_v_rowmajor;
bool has_logits_soft_cap;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
@@ -821,3 +1028,8 @@ struct fmha_fwd_appendkv_traits
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
fmha_fwd_appendkv_args,
const ck_tile::stream_config&);
using fmha_batch_prefill_traits = fmha_fwd_traits;
float fmha_batch_prefill(fmha_batch_prefill_traits,
fmha_batch_prefill_args,
const ck_tile::stream_config&);

View File

@@ -21,8 +21,7 @@ class HandlerId(IntEnum):
ops = []
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
if full_module_name not in sys.modules:
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
unwanted_prefix = 'fmha_'
handlers = dict(
[(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__,

View File

@@ -54,6 +54,7 @@
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"

View File

@@ -487,6 +487,9 @@ struct log2e<float>
template <typename T = double>
constexpr T log2e_v = log2e<T>::value;
template <typename T = double>
constexpr T log2e_rcp_v = 1. / log2e<T>::value;
CK_TILE_DEVICE
float exp2(float x) { return exp2f(x); };
@@ -1380,6 +1383,44 @@ CK_TILE_DEVICE double exp<double>(double x)
return exp(x);
};
template <typename T>
CK_TILE_DEVICE T tanh_fast(T x)
{
return type_convert<T>((exp<T>(2.0 * type_convert<float>(x)) - 1.0) /
(exp<T>(2.0 * type_convert<float>(x)) + 1.0));
};
template <>
CK_TILE_DEVICE float tanh_fast<float>(float x)
{
// float a = __builtin_amdgcn_sinh(x);
// float b = __builtin_amdgcn_cosh(x);
// float e = a * __builtin_amdgcn_rcpf(b);
// return e;
float a = 2.0f * log2e_v<float> * x;
a = __builtin_amdgcn_exp2f(a);
a = __builtin_amdgcn_rcpf(a + 1.0f);
a = 2 * a;
a = 1 - a;
return a;
// float e, r, s, t, d;
// float a = x;
// s = abs(a);
// t = -log2e_v<float> * 2.0f * s;
// e = __builtin_amdgcn_exp2f(t);
// d = e + 1.0f;
// r = __builtin_amdgcn_rcpf(d);
// r = e * (-r) + r;
// if (s < 4.997253418e-3f) r = a;
// union fipnr {float f; unsigned int i;};
// fipnr r_; r_.f = r;
// fipnr a_; a_.f = a;
// { r_.i = (r_.i|(a_.i&0x80000000)); r = r_.f; }
// return r;
};
template <typename T>
CK_TILE_DEVICE T log(T x)
{

View File

@@ -18,32 +18,8 @@
namespace ck_tile {
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
template <typename TileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
@@ -51,35 +27,11 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
}
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
@@ -138,42 +90,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
}
template <typename LdsTileWindow_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})

View File

@@ -210,6 +210,27 @@ struct tensor_view
bool_constant<pre_nop>{});
}
template <typename X,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t coord_extra_offset,
index_t linear_offset,
bool_constant<pre_nop> = {}) const
{
return buf_.template async_get_raw<X>(
smem,
(coord.get_offset() + coord_extra_offset) / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
}
template <typename X,
bool pre_nop = false,
typename std::enable_if<

View File

@@ -0,0 +1,731 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
/**
* @brief This class provides tile (windowed) view and access to the device memory.
*
* @note This tile window does not support single issue you need to use tile_window_linear
* structure for this purpose
*
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
* @tparam NumCoord TBD
*/
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
index_t HsGatherDim = 0,
index_t NumCoord = 1,
index_t YsGatherDim = 0>
struct tile_scatter_gather
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
using PageIdxArray = remove_cvref_t<StaticPageIndexArray_>;
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static_assert(NumCoord == 1);
// TODO: check WindowLengths and StaticTileDistribution are consistent
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
static_assert(TileDstr::is_static(), "wrong!");
static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
"wrong! inconsistent # of diemsnions");
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
using WindowAdaptorCoord =
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
using BottomTensorCoord =
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
struct load_store_traits
{
private:
static constexpr auto get_vector_dim_y_scalar_per_vector()
{
const auto [ys_vector_lengths, ys_vector_strides] =
tile_scatter_gather::get_window_adaptor_ys_safe_vector_length_strides();
index_t VectorDimY_ = 0;
index_t ScalarPerVector_ = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
{
ScalarPerVector_ = ys_vector_lengths[i];
VectorDimY_ = i;
}
}
return make_tuple(VectorDimY_, ScalarPerVector_);
}
public:
static constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
static constexpr index_t ScalarPerVector =
get_vector_dim_y_scalar_per_vector().template at<1>();
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
// using vector_t = typename vector_type_t::type;
using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
private:
static constexpr auto scalars_per_access_ = [] {
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
constexpr auto NDimY_ = NDimY;
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
}();
static constexpr auto get_space_filling_curve()
{
constexpr auto tile_dstr = TileDstr{};
constexpr auto thread_tensor_lengths_ys =
to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
// FIXME: need logic to judge dim access order
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
return space_filling_curve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access_)>{};
}
public:
using SFC_Ys = decltype(get_space_filling_curve());
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
};
static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord;
CK_TILE_DEVICE constexpr tile_scatter_gather() = default;
CK_TILE_DEVICE constexpr tile_scatter_gather(const BottomTensorView& bottom_tensor_view,
const WindowLengths& window_lengths,
const BottomTensorIndex& window_origin,
const TileDstr& tile_distribution,
const PageIdxArray& page_idx)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin},
tile_dstr_{tile_distribution},
page_idx_{page_idx},
pre_computed_coords_{}
{
#if 0 // debug
// TODO: this use more register for FA, but less register for GEMM
// need investigation
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(NDimP == 2)
{
window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
// TODO: this use less register for FA, but more register for GEMM
// need investigation
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_distribution),
array<index_t, NDimY>{0}));
#endif
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
using Traits = load_store_traits;
using SFC_Ys = typename Traits::SFC_Ys;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
}
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
{
return TileDstr::is_static();
}
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
CK_TILE_DEVICE constexpr void
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
{
bottom_tensor_view_.buf_.p_data_ = data;
}
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
WindowAdaptorCoord& window_adaptor_thread_coord,
BottomTensorCoord& bottom_tensor_thread_coord,
const ATopIndex& idx_diff_adaptor_top) const
{
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
window_adaptor_thread_coord,
idx_diff_adaptor_top,
idx_diff_adaptor_bottom);
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
bottom_tensor_thread_coord,
idx_diff_adaptor_bottom);
}
// return vector dimension among [y0, y1, ...]
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
{
// bottom tensor top dimension vector lengths and strides
const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
// window vector lengths/strides
const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
// window adaptor [p0, p1, ..., y0, y1, ...]
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
-1};
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
-1};
constexpr auto window_adaptor_bottom_dims =
WindowAdaptor::get_bottom_dimension_hidden_ids();
set_container_subset(window_adaptor_vector_lengths,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_lengths);
set_container_subset(window_adaptor_vector_strides,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_strides);
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
window_adaptor_vector_lengths, window_adaptor_vector_strides);
// [y0, y1, ...]
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
NDimWindowAdaptorTop,
1>::type{};
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
}
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; }
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
return dst_tensor;
}
template <typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
bool_constant<oob_conditional_check>{});
#if 1
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j / Traits::PackedSize];
});
#else
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % Traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
forward_step_scatter);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
// TODO: currently async load only implemented in inline asm
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
// using LdsTensorView = typename LdsTileWindow::BottomTensorView;
using LdsDataType = typename LdsTileWindow::DataType;
// using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
const index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType);
const index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
m0_set_with_memory(m0_init_value); // This should be wave independent
using Traits = load_store_traits;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
forward_step_scatter);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
m0_inc_with_memory(size_per_issue);
}
});
});
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// printf("off %d\n", page_idx_[I0]);
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<0>{}];
const auto page_offset = page_idx_[idx_gather];
// printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
// idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
// read from distributed tensor
// vector_type_t vec;
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
// printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j);
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
// write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
vec_value,
bool_constant<oob_conditional_check>{});
// printf("coord_offset:%d, scatter_offset:%d \n",
// bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
forward_step_scatter);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
// move thread's botom tensor coordiante
// [x0', x1', ... ] ==> [offset]
// also move window-origin
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
{
window_origin_ += step;
BottomTensorIndex step_new = step;
step_new(HsGatherDim) = 0;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
pre_computed_coords_(iCoord)(I1),
step_new);
});
}
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx)
{
page_idx_ = new_idx;
// static_for<0, 2, 1>{}([&](auto k0) {
// printf("update tid %d %d \n", threadIdx.x, page_idx_[k0]);
// });
}
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
{
window_origin_ = new_window_origin;
#if 0 // debug
// TODO: this use more register for FA, but less register for GEMM
// need investigation
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(NDimP == 2)
{
window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
// TODO: this use less register for FA, but more register for GEMM
// need investigation
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
#endif
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
using Traits = load_store_traits;
using SFC_Ys = typename Traits::SFC_Ys;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
}
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
//
WindowLengths window_lengths_;
// origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex window_origin_;
// Tile tensor distribution, which contains:
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_;
PageIdxArray page_idx_;
// this contains:
// per-thread coordinate for window adaptor
// per-thread coordinate for bottom tensor
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
};
// TODO: use strategy
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
index_t HsGatherDim = 0,
index_t NumCoord = 1>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
const StaticPageIndexArray_& page_idx,
number<HsGatherDim> = {},
number<NumCoord> = {})
{
return tile_scatter_gather<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
remove_cvref_t<StaticPageIndexArray_>,
HsGatherDim,
NumCoord>{
tensor_view, window_lengths, origin, tile_distribution, page_idx};
}
template <typename TensorView,
typename WindowLengths,
typename StaticTileDistribution,
typename StaticPageIndexArray,
index_t HsGatherDim>
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const multi_index<TensorView::get_num_of_dimension()>& origin,
const StaticTileDistribution& tile_distribution,
const StaticPageIndexArray& page_idx,
number<HsGatherDim> = {})
{
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
origin,
tile_distribution,
page_idx,
number<HsGatherDim>{});
}
template <typename TensorView,
typename WindowLengths,
typename StaticTileDistribution,
typename StaticPageIndexArray,
index_t HsGatherDim>
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const StaticTileDistribution& tile_distribution,
const StaticPageIndexArray& page_idx,
number<HsGatherDim> = {})
{
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_distribution,
page_idx,
number<HsGatherDim>{});
}
} // namespace ck_tile

View File

@@ -18,6 +18,13 @@
#pragma once
namespace ck_tile {
template <typename TileWindow_>
CK_TILE_DEVICE void move_tile_window(TileWindow_& window,
const typename TileWindow_::BottomTensorIndex& step)
{
window.move(step);
}
// input a lds store tile, extract some information from it
// used to set m0 value for gfx9 serious
template <typename LdsTileWindow_>

View File

@@ -9,12 +9,16 @@
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
#include "ck_tile/ops/fmha/block/page_block_navigator.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"

View File

@@ -0,0 +1,274 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <type_traits>
#include <ck_tile/core/numeric/math.hpp>
#include <ck_tile/core/numeric/type_convert.hpp>
#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 0
#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 1
#ifndef CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT
#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
#endif
namespace ck_tile {
template <typename ImplMask>
struct StandardAttentionParams
{
__device__ __host__ StandardAttentionParams(const ImplMask& impl_mask_, float sm_scale_)
: impl_mask(impl_mask_), sm_scale(sm_scale_)
{
}
const ImplMask& impl_mask;
float sm_scale;
};
template <typename ImplMask, bool UseExp2 = false>
struct LogitsSoftCapParams
{
__device__
LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_)
: impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_)
{
if(0.f < logits_soft_cap)
{
logits_soft_cap_rcp = __builtin_amdgcn_rcpf(logits_soft_cap);
}
else
{
logits_soft_cap_rcp = 0.f;
}
// move computation here to prevent compiler from generating inefficient instruction
// sequence
if constexpr(UseExp2)
{
logits_soft_cap = log2e_v<float> * logits_soft_cap;
logits_soft_cap_rcp = sm_scale * log2e_rcp_v<float> * logits_soft_cap_rcp;
}
}
__host__
LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_)
: impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_)
{
if(0.f < logits_soft_cap)
{
logits_soft_cap_rcp = 1.f / logits_soft_cap;
}
else
{
logits_soft_cap_rcp = 0.f;
}
// move computation here to prevent compiler from generating inefficient instruction
// sequence
if constexpr(UseExp2)
{
logits_soft_cap = log2e_v<float> * logits_soft_cap;
logits_soft_cap_rcp = sm_scale * log2e_rcp_v<float> * logits_soft_cap_rcp;
}
}
__device__ __host__ LogitsSoftCapParams(const ImplMask& impl_mask_,
float sm_scale_,
float logits_soft_cap_,
float logits_soft_cap_rcp_)
: impl_mask(impl_mask_),
sm_scale(sm_scale_),
logits_soft_cap(logits_soft_cap_),
logits_soft_cap_rcp(logits_soft_cap_rcp_)
{
// move computation here to prevent compiler from generating inefficient instruction
// sequence
if constexpr(UseExp2)
{
logits_soft_cap = log2e_v<float> * logits_soft_cap;
logits_soft_cap_rcp = sm_scale * log2e_rcp_v<float> * logits_soft_cap_rcp;
}
}
const ImplMask& impl_mask;
float sm_scale;
float logits_soft_cap;
float logits_soft_cap_rcp;
};
struct StandardAttention
{
__device__ __host__ StandardAttention() = default;
template <typename Params, typename T>
__device__ __forceinline__ T QueryTransform(const Params& params, T q) const
{
return type_convert<float>(q) * params.sm_scale;
}
/// NOTICE: For better performance, we simpliy transform thread buffer without calculating
/// qo_idx/kv_idx.
template <typename Params, typename T>
__device__ __forceinline__ T LogitsTransform([[maybe_unused]] const Params& params,
T logits,
[[maybe_unused]] uint32_t batch_idx,
/*uint32_t qo_idx, uint32_t kv_idx,*/
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return logits;
}
template <typename Params>
__device__ __forceinline__ bool LogitsMask(const Params& params,
[[maybe_unused]] uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
}
};
template <bool UseExp2 = false>
struct LogitsSoftCap
{
__device__ __host__ LogitsSoftCap() = default;
template <typename Params, typename T>
__device__ __forceinline__ T QueryTransform(const Params& params, T q) const
{
if constexpr(UseExp2)
{
return q;
}
else
{
return type_convert<float>(q) * params.sm_scale;
}
}
/// NOTICE: For better performance, we simpliy transform thread buffer without calculating
/// qo_idx/kv_idx.
template <typename Params, typename T>
__device__ __forceinline__ T LogitsTransform(const Params& params,
T logits,
[[maybe_unused]] uint32_t batch_idx,
/*uint32_t qo_idx, uint32_t kv_idx,*/
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
if constexpr(UseExp2)
{
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
return params.logits_soft_cap *
tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
return params.sm_scale * type_convert<float>(logits) *
rcp<float>(1.f + abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
#endif
}
else
{
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
return params.logits_soft_cap *
tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
return type_convert<float>(logits) *
rcp<float>(1.f + abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
#endif
}
}
template <typename Params>
__device__ __forceinline__ bool LogitsMask(const Params& params,
[[maybe_unused]] uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
}
};
constexpr uint32_t CUSTOM_MASK = 1U;
constexpr uint32_t SLIDING_WINDOW = 2U;
constexpr uint32_t LOGITS_SOFT_CAP = 4U;
constexpr uint32_t ALIBI = 8U;
template <uint32_t VARIANT_CODE, bool UseExp2 = false>
struct ComposedAttention
{
static constexpr bool use_exp2 = UseExp2;
static constexpr bool use_logits_soft_cap = (VARIANT_CODE & LOGITS_SOFT_CAP) != 0;
__device__ __host__ ComposedAttention() = default;
template <typename Params, typename T>
__device__ __forceinline__ T QueryTransform(const Params& params, T q) const
{
if constexpr(use_logits_soft_cap && UseExp2)
{
return q;
}
return type_convert<float>(q) * params.sm_scale;
}
/// NOTICE: For better performance, we simpliy transform thread buffer without calculating
/// qo_idx/kv_idx.
template <typename Params, typename T>
__device__ __forceinline__ T LogitsTransform(const Params& params,
T logits,
[[maybe_unused]] uint32_t batch_idx,
/*uint32_t qo_idx, uint32_t kv_idx,*/
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
if constexpr(use_logits_soft_cap)
{
if constexpr(UseExp2)
{
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
return params.logits_soft_cap *
tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
return params.sm_scale * type_convert<float>(logits) *
rcp<float>(1.f +
abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
#endif
}
else
{
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
return params.logits_soft_cap *
tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
return type_convert<float>(logits) *
rcp<float>(1.f +
abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
#endif
}
}
return logits;
}
template <typename Params>
__device__ __forceinline__ bool LogitsMask(const Params& params,
[[maybe_unused]] uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
}
};
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include <string>
#include <type_traits>
@@ -47,11 +48,13 @@ struct FmhaFwdKernel
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
@@ -94,7 +97,7 @@ struct FmhaFwdKernel
"w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" );
#undef _SS_
#undef _TS_
@@ -139,6 +142,28 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_o;
};
struct FmhaFwdLogitsSoftCapKargs
{
FmhaFwdLogitsSoftCapKargs() = default;
void init_logits_soft_cap(float logits_soft_cap_)
{
if(0 < logits_soft_cap_)
{
logits_soft_cap = logits_soft_cap_;
logits_soft_cap_rcp = 1.f / logits_soft_cap;
}
else
{
logits_soft_cap = 0.f;
logits_soft_cap_rcp = 0.f;
}
}
float logits_soft_cap;
float logits_soft_cap_rcp;
};
struct FmhaFwdCommonBiasKargs
{
const void* bias_ptr = nullptr;
@@ -242,7 +267,8 @@ struct FmhaFwdKernel
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
@@ -260,7 +286,8 @@ struct FmhaFwdKernel
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
@@ -269,6 +296,13 @@ struct FmhaFwdKernel
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
struct BlockIndices
{
ck_tile::index_t batch_idx;
ck_tile::index_t qo_head_idx;
ck_tile::index_t kv_head_idx;
};
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
@@ -287,6 +321,7 @@ struct FmhaFwdKernel
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
@@ -343,6 +378,7 @@ struct FmhaFwdKernel
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
{}, // placeholder for dropout
{}, // placeholder for logits_soft_cap
batch_stride_q,
batch_stride_k,
batch_stride_v,
@@ -398,6 +434,10 @@ struct FmhaFwdKernel
kargs.batch_stride_randval = batch_stride_randval;
kargs.is_store_randval = s_randval;
}
if constexpr(kHasLogitsSoftCap)
{
kargs.init_logits_soft_cap(logits_soft_cap);
}
return kargs;
}
@@ -421,6 +461,7 @@ struct FmhaFwdKernel
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
@@ -465,6 +506,7 @@ struct FmhaFwdKernel
scale_s,
scale_p,
scale_o,
logits_soft_cap,
stride_q,
stride_k,
stride_v,
@@ -512,6 +554,7 @@ struct FmhaFwdKernel
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
@@ -556,6 +599,7 @@ struct FmhaFwdKernel
scale_s,
scale_p,
scale_o,
logits_soft_cap,
stride_q,
stride_k,
stride_v,
@@ -603,6 +647,7 @@ struct FmhaFwdKernel
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
@@ -652,6 +697,7 @@ struct FmhaFwdKernel
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
{}, // placeholder for dropout
{}, // placeholder for logits_soft_cap
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
@@ -703,6 +749,10 @@ struct FmhaFwdKernel
kargs.nhead_stride_randval = nhead_stride_randval;
kargs.is_store_randval = s_randval;
}
if constexpr(kHasLogitsSoftCap)
{
kargs.init_logits_soft_cap(logits_soft_cap);
}
return kargs;
}
@@ -727,6 +777,7 @@ struct FmhaFwdKernel
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
@@ -765,6 +816,7 @@ struct FmhaFwdKernel
scale_s,
scale_p,
scale_o,
logits_soft_cap,
stride_q,
stride_k,
stride_v,
@@ -806,6 +858,7 @@ struct FmhaFwdKernel
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
@@ -844,6 +897,7 @@ struct FmhaFwdKernel
scale_s,
scale_p,
scale_o,
logits_soft_cap,
stride_q,
stride_k,
stride_v,
@@ -1307,6 +1361,21 @@ struct FmhaFwdKernel
}
}();
AttentionVariant variant;
const auto variant_params = [&] {
if constexpr(kHasLogitsSoftCap)
{
return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
}
else
{
return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
}
}();
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
@@ -1328,6 +1397,9 @@ struct FmhaFwdKernel
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout);
}
@@ -1342,6 +1414,9 @@ struct FmhaFwdKernel
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout);
}

View File

@@ -6,6 +6,8 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include <string>
#include <type_traits>
@@ -43,14 +45,15 @@ struct FmhaFwdSplitKVKernel
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
static constexpr bool kMergeNumHeadGroupsSeqLenQ =
FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
static_assert(!kMergeNumHeadGroupsSeqLenQ ||
@@ -95,7 +98,7 @@ struct FmhaFwdSplitKVKernel
"w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) +
(kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
#undef _SS_
@@ -150,6 +153,28 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t split_stride_o_acc;
};
struct LogitsSoftCapKargs
{
LogitsSoftCapKargs() = default;
void init_logits_soft_cap(float logits_soft_cap_)
{
if(0 < logits_soft_cap_)
{
logits_soft_cap = logits_soft_cap_;
logits_soft_cap_rcp = 1.f / logits_soft_cap;
}
else
{
logits_soft_cap = 0.f;
logits_soft_cap_rcp = 0.f;
}
}
float logits_soft_cap;
float logits_soft_cap_rcp;
};
struct CommonBiasKargs
{
const void* bias_ptr = nullptr;
@@ -207,7 +232,8 @@ struct FmhaFwdSplitKVKernel
EmptyKargs<0>>>,
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>
std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>,
std::conditional_t<kHasLogitsSoftCap, LogitsSoftCapKargs, EmptyKargs<3>>
{
const int32_t* seqlen_k_ptr;
@@ -229,7 +255,8 @@ struct FmhaFwdSplitKVKernel
EmptyKargs<0>>>,
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, EmptyKargs<3>>
std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, EmptyKargs<3>>,
std::conditional_t<kHasLogitsSoftCap, LogitsSoftCapKargs, EmptyKargs<4>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
@@ -243,6 +270,13 @@ struct FmhaFwdSplitKVKernel
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
struct BlockIndices
{
ck_tile::index_t batch_idx;
ck_tile::index_t qo_head_idx;
ck_tile::index_t kv_head_idx;
};
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
@@ -268,6 +302,7 @@ struct FmhaFwdSplitKVKernel
const void* cache_batch_idx,
float scale_s,
float scale_p,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
@@ -324,6 +359,7 @@ struct FmhaFwdSplitKVKernel
{}, // placeholder for mask
{}, // placeholder for fp8_static_quant args
{}, // placeholder for paged-block table or cache_batch_idx
{}, // placeholder for logits_soft_cap
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
batch_stride_q,
batch_stride_k,
@@ -363,6 +399,10 @@ struct FmhaFwdSplitKVKernel
{
kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
}
if constexpr(kHasLogitsSoftCap)
{
kargs.init_logits_soft_cap(logits_soft_cap);
}
return kargs;
}
@@ -392,6 +432,7 @@ struct FmhaFwdSplitKVKernel
bool is_gappy,
float scale_s,
float scale_p,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
@@ -444,6 +485,7 @@ struct FmhaFwdSplitKVKernel
{}, // placeholder for mask
{}, // placeholder for fp8_static_quant args
{}, // placeholder for paged-block table
{}, // placeholder for logits_soft_cap
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
@@ -478,6 +520,10 @@ struct FmhaFwdSplitKVKernel
kargs.page_block_size = page_block_size;
kargs.is_gappy = is_gappy;
}
if constexpr(kHasLogitsSoftCap)
{
kargs.init_logits_soft_cap(logits_soft_cap);
}
return kargs;
}
@@ -968,6 +1014,21 @@ struct FmhaFwdSplitKVKernel
}
}();
AttentionVariant variant;
const auto variant_params = [&] {
if constexpr(kHasLogitsSoftCap)
{
return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
}
else
{
return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
}
}();
BlockIndices block_indices{i_batch, i_nhead, i_nhead_k};
auto o_acc_tile = [&, i_split_ = i_split]() {
if constexpr(kDoFp8StaticQuant)
{
@@ -991,6 +1052,9 @@ struct FmhaFwdSplitKVKernel
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
}
@@ -1008,6 +1072,9 @@ struct FmhaFwdSplitKVKernel
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
}

View File

@@ -0,0 +1,900 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_,
typename Policy_ = BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy>
struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
Problem::kPadHeadDimV == true);
static constexpr bool kPadSeqLenQ = true;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
!kHasLogitsSoftCap)) ||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
#if CK_TILE_FMHA_FWD_FAST_EXP2
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
#endif
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
// minimize occupancy
if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)
{
return 1;
}
if constexpr(kQKHeaddim <= 32)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
FmhaMask::IsMasking)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim <= 64)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 2;
else
return 3;
}
else if constexpr(kQKHeaddim <= 128)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim <= 192)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
else
{
return 1;
};
}
}();
static constexpr const char* name = "qr_async";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename LSEElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const KElementFunction& /*k_element_func*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
const index_t* page_idx,
const index_t stride_k,
const index_t stride_v,
DropoutType& dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
// K tile in LDS
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
auto k_lds_store = generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
{0, 0, 0});
},
number<Policy::NumKVLdsBuffers>{});
auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
auto k_lds_load =
make_tile_window(k_lds_Load_view,
Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
{0, 0});
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQRegTileDistribution<Problem>());
q_dram_window.init_raw();
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
auto q = decltype(load_tile(q_dram_window)){};
// TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
// however, q would be cleared in the constructor of static distributed tensor
// set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
load_tile_raw(q, q_dram_window);
__builtin_amdgcn_sched_barrier(0);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{
if(num_total_loop <= 0)
{
if constexpr(kStoreLSE)
{
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleard, return it
return o_acc;
}
__builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
}
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
auto k_dist = Policy::template MakeKDramTileDistribution<Problem>();
auto k_coord = k_dist.calculate_index();
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
statically_indexed_array<index_t, NRepeat> k_offsets;
static_for<0, NRepeat, 1>{}([&](auto n0) {
k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
});
auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
k_dist,
k_offsets); // K DRAM tile window for
k_dram_window.init_raw();
constexpr auto k_oob_ck = bool_constant<true>{};
constexpr auto k_pre_np = [&]() {
if constexpr(kPadSeqLenK &&
(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)))
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
auto v_coord = v_dist.calculate_index();
const auto VPageIndexDim = I1;
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3];
statically_indexed_array<index_t, V_KRepeat> v_offsets;
(void)stride_k;
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v;
});
auto v_dram_window =
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
v_dist,
v_offsets,
VPageIndexDim);
// prefetch K tile
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(1 <= k0_loops);
static_assert(1 <= k1_loops);
// main loop
do
{
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
if constexpr(k0_loops > 1)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_of_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
});
}
// TODO: this to fix a bug when loop smaller than 2,
// the following fence/barrier will be scheduled inside 1st loop
if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0);
async_load_fence();
__builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v;
});
v_dram_window.update_page_idx(v_offsets);
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(
s_acc,
get_slice_tile(
q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
}
__builtin_amdgcn_sched_barrier(1);
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x += type_convert<SaccDataType>(bias_element_func(y));
#else
x += log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
#endif
},
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
if constexpr(kHasLogitsSoftCap)
{
auto apply_logits_transform =
[&variant, &variant_params, &block_indices](auto& x) {
x = variant.LogitsTransform(variant_params,
variant.QueryTransform(variant_params, x),
block_indices.batch_idx,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
};
#if !CK_TILE_FMHA_FWD_FAST_EXP2
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
apply_logits_transform(s_acc.thread_buf_[i]);
}
#else
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
apply_logits_transform(s_acc.thread_buf_[i]);
}
#endif
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j}
__builtin_amdgcn_sched_barrier(0x7F);
// store & prefetch next v, after the max reduction
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(
v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
}
if constexpr(k1_loops > 1)
{
move_tile_window(
v_dram_window,
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] =
page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value] * stride_v;
});
v_dram_window.update_page_idx(v_offsets);
}
__builtin_amdgcn_sched_barrier(0);
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration. alibi does not have this problem
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
if constexpr(kHasLogitsSoftCap)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
}
}
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
#endif
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
if constexpr(kHasLogitsSoftCap)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}
}();
#else
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
});
});
if constexpr(kHasDropout)
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
}
const auto p = [&]() {
if constexpr(std::is_same_v<PDataType, fp16_t>)
return impl::cast_tile_pk_fp16_fp32<PDataType>(
tile_elementwise_in(p_compute_element_func, p_compute));
else
return cast_tile<PDataType>(
tile_elementwise_in(p_compute_element_func, p_compute));
}();
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 +
v_coord[VPageIndexDim] + k0.value] *
stride_v;
});
v_dram_window.update_page_idx(v_offsets);
}
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
}
if constexpr(i_k1 < k1_loops - 1)
move_tile_window(v_dram_window, {0, kK1});
});
}
i_total_loops++;
if(i_total_loops < num_total_loop)
{
page_idx += kN0;
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
static_for<0, NRepeat, 1>{}([&](auto n0) {
k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
});
k_dram_window.update_page_idx(k_offsets);
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
}
// tail
{
block_sync_lds();
gemm_1(
o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
}
} while(i_total_loops < num_total_loop);
// store lse
if constexpr(kStoreLSE)
{
auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
}
else
{
if constexpr(kHasLogitsSoftCap)
{
lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
}
else
{
lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]);
}
}
#else
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
#endif
});
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() {
if constexpr(FmhaMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
else
return 1 / l[i_idx];
}();
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
const index_t* page_idx,
const index_t stride_k,
const index_t stride_v,
DropoutType& dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
randval_dram_block_window_tmp,
lse_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
mask,
position_encoding,
scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
page_idx,
stride_k,
stride_v,
dropout);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,18 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace ck_tile {
// This pipeline is qkv all located in LDS
using BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy =
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopy = */ true,
/* NumPrefetchK = */ 3,
/* NumPrefetchV = */ 3>;
} // namespace ck_tile

View File

@@ -27,6 +27,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
@@ -46,15 +47,21 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
!kHasLogitsSoftCap)) ||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
@@ -128,7 +135,9 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
@@ -150,6 +159,9 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
{
@@ -453,9 +465,34 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
if constexpr(kHasLogitsSoftCap)
{
auto apply_logits_transform =
[&variant, &variant_params, &block_indices](auto& x) {
x = variant.LogitsTransform(variant_params,
variant.QueryTransform(variant_params, x),
block_indices.batch_idx,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
};
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
apply_logits_transform(s_acc.thread_buf_[i]);
}
#else
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
apply_logits_transform(s_acc.thread_buf_[i]);
}
#endif
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
}
move_tile_window(bias_dram_window, {0, kN0});
@@ -574,7 +611,14 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max);
}
}
#else
p_compute(i_j_idx) = exp(s_new[i_j_idx] - get_validated_m(m[i_idx]));
@@ -603,8 +647,15 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}
}();
#else
@@ -711,7 +762,14 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
}
else
{
lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
if constexpr(kHasLogitsSoftCap)
{
lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
}
else
{
lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
}
}
#else
lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
@@ -757,7 +815,9 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename PositionEncoding>
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
@@ -771,6 +831,9 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
{
@@ -794,6 +857,9 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
mask,
position_encoding,
scale_s,
variant,
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
}

View File

@@ -26,6 +26,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
@@ -45,15 +46,21 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
!kHasLogitsSoftCap)) ||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
@@ -127,7 +134,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
@@ -149,6 +158,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
{
@@ -401,9 +413,28 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
if constexpr(kHasLogitsSoftCap)
{
auto apply_logits_transform =
[&variant, &variant_params, &block_indices](auto& x) {
x = variant.LogitsTransform(variant_params,
variant.QueryTransform(variant_params, x),
block_indices.batch_idx,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
};
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(apply_logits_transform, s_acc);
#else
tile_elementwise_inout(apply_logits_transform, s_acc);
#endif
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
}
move_tile_window(bias_dram_window, {0, kN0});
@@ -497,7 +528,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
}
}
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
@@ -522,8 +560,16 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}
}();
#else
@@ -620,7 +666,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
else
{
lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
if constexpr(kHasLogitsSoftCap)
{
lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
}
else
{
lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
}
}
#else
lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
@@ -662,7 +715,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename PositionEncoding>
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
@@ -676,6 +731,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
{
@@ -699,6 +757,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
mask,
position_encoding,
scale_s,
variant,
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
}

View File

@@ -20,6 +20,7 @@ template <typename QDataType_,
typename ODataType_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
typename AttentionVariant_,
typename FmhaMask_,
typename Traits_>
struct BlockFmhaPipelineProblem
@@ -36,6 +37,7 @@ struct BlockFmhaPipelineProblem
using OaccDataType = remove_cvref_t<OaccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using AttentionVariant = remove_cvref_t<AttentionVariant_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
@@ -50,6 +52,7 @@ struct BlockFmhaPipelineProblem
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kHasDropout = Traits::kHasDropout;
@@ -69,6 +72,7 @@ template <typename QDataType_,
typename ODataType_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
typename AttentionVariant_,
typename FmhaMask_,
typename Traits_>
struct BlockFmhaFwdSplitKVPipelineProblem
@@ -84,6 +88,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem
using OaccDataType = remove_cvref_t<OaccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using AttentionVariant = remove_cvref_t<AttentionVariant_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
@@ -98,6 +103,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;

View File

@@ -5,8 +5,8 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
@@ -28,6 +28,7 @@ struct BlockFmhaPipelineQRKSVS
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
@@ -47,14 +48,20 @@ struct BlockFmhaPipelineQRKSVS
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
!kHasLogitsSoftCap)) ||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
@@ -101,7 +108,7 @@ struct BlockFmhaPipelineQRKSVS
else
{
return 1;
};
}
}
}();
@@ -128,7 +135,9 @@ struct BlockFmhaPipelineQRKSVS
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
@@ -147,6 +156,9 @@ struct BlockFmhaPipelineQRKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout) const
{
@@ -380,9 +392,28 @@ struct BlockFmhaPipelineQRKSVS
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
if constexpr(kHasLogitsSoftCap)
{
auto apply_logits_transform =
[&variant, &variant_params, &block_indices](auto& x) {
x = variant.LogitsTransform(variant_params,
variant.QueryTransform(variant_params, x),
block_indices.batch_idx,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
};
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(apply_logits_transform, s_acc);
#else
tile_elementwise_inout(apply_logits_transform, s_acc);
#endif
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
@@ -398,7 +429,12 @@ struct BlockFmhaPipelineQRKSVS
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
}
}
@@ -450,7 +486,14 @@ struct BlockFmhaPipelineQRKSVS
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
}
}
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
@@ -475,8 +518,16 @@ struct BlockFmhaPipelineQRKSVS
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}
}();
#else
@@ -574,7 +625,14 @@ struct BlockFmhaPipelineQRKSVS
}
else
{
lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
if constexpr(kHasLogitsSoftCap)
{
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
}
else
{
lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
}
}
#else
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
@@ -614,7 +672,9 @@ struct BlockFmhaPipelineQRKSVS
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
@@ -625,6 +685,9 @@ struct BlockFmhaPipelineQRKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout) const
{
@@ -645,6 +708,9 @@ struct BlockFmhaPipelineQRKSVS
mask,
position_encoding,
scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout);
}

View File

@@ -29,6 +29,7 @@ struct BlockFmhaPipelineQRKSVSAsync
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
@@ -53,13 +54,19 @@ struct BlockFmhaPipelineQRKSVSAsync
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
Problem::kPadHeadDimV == true);
static constexpr bool kPadSeqLenQ = true;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kPadSeqLenQ = true;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
!kHasLogitsSoftCap)) ||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
@@ -153,7 +160,9 @@ struct BlockFmhaPipelineQRKSVSAsync
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
@@ -172,6 +181,9 @@ struct BlockFmhaPipelineQRKSVSAsync
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout) const
{
@@ -435,9 +447,34 @@ struct BlockFmhaPipelineQRKSVSAsync
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
if constexpr(kHasLogitsSoftCap)
{
auto apply_logits_transform =
[&variant, &variant_params, &block_indices](auto& x) {
x = variant.LogitsTransform(variant_params,
variant.QueryTransform(variant_params, x),
block_indices.batch_idx,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
};
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
apply_logits_transform(s_acc.thread_buf_[i]);
}
#else
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
apply_logits_transform(s_acc.thread_buf_[i]);
}
#endif
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
@@ -454,7 +491,12 @@ struct BlockFmhaPipelineQRKSVSAsync
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
}
}
@@ -543,7 +585,14 @@ struct BlockFmhaPipelineQRKSVSAsync
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
}
}
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
@@ -568,8 +617,15 @@ struct BlockFmhaPipelineQRKSVSAsync
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}
}();
#else
@@ -695,7 +751,14 @@ struct BlockFmhaPipelineQRKSVSAsync
}
else
{
lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]);
if constexpr(kHasLogitsSoftCap)
{
lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
}
else
{
lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]);
}
}
#else
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
@@ -735,7 +798,9 @@ struct BlockFmhaPipelineQRKSVSAsync
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
@@ -746,6 +811,9 @@ struct BlockFmhaPipelineQRKSVSAsync
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout) const
{
@@ -766,6 +834,9 @@ struct BlockFmhaPipelineQRKSVSAsync
mask,
position_encoding,
scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout);
}

View File

@@ -7,6 +7,7 @@
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
@@ -27,6 +28,7 @@ struct BlockFmhaPipelineQSKSVS
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
@@ -44,14 +46,21 @@ struct BlockFmhaPipelineQSKSVS
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
!kHasLogitsSoftCap)) ||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ =
@@ -95,7 +104,9 @@ struct BlockFmhaPipelineQSKSVS
return 1;
}
else
{
return 1;
}
}
}();
@@ -122,7 +133,9 @@ struct BlockFmhaPipelineQSKSVS
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
@@ -141,6 +154,9 @@ struct BlockFmhaPipelineQSKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& /* unused_dropout */) const
{
@@ -380,9 +396,28 @@ struct BlockFmhaPipelineQSKSVS
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
if constexpr(kHasLogitsSoftCap)
{
auto apply_logits_transform =
[&variant, &variant_params, &block_indices](auto& x) {
x = variant.LogitsTransform(variant_params,
variant.QueryTransform(variant_params, x),
block_indices.batch_idx,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
};
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(apply_logits_transform, s_acc);
#else
tile_elementwise_inout(apply_logits_transform, s_acc);
#endif
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
@@ -398,7 +433,12 @@ struct BlockFmhaPipelineQSKSVS
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
}
}
@@ -450,7 +490,14 @@ struct BlockFmhaPipelineQSKSVS
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
}
}
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
@@ -481,8 +528,16 @@ struct BlockFmhaPipelineQSKSVS
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}
}();
#else
@@ -571,7 +626,14 @@ struct BlockFmhaPipelineQSKSVS
}
else
{
lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
if constexpr(kHasLogitsSoftCap)
{
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
}
else
{
lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
}
}
#else
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
@@ -611,7 +673,9 @@ struct BlockFmhaPipelineQSKSVS
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
@@ -622,6 +686,9 @@ struct BlockFmhaPipelineQSKSVS
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout) const
{
@@ -642,6 +709,9 @@ struct BlockFmhaPipelineQSKSVS
mask,
position_encoding,
scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout);
}

View File

@@ -13,6 +13,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* paddding for hdim_q */,
bool kPadHeadDimV_ /* paddding for hdim_v */,
bool kHasLogitsSoftCap_,
BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kStoreLSE_,
@@ -25,6 +26,7 @@ struct TileFmhaTraits
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kStoreLSE = kStoreLSE_;
@@ -37,6 +39,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* paddding for hdim_q */,
bool kPadHeadDimV_ /* paddding for hdim_v */,
bool kHasLogitsSoftCap_,
BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
@@ -51,6 +54,7 @@ struct TileFmhaFwdSplitKVTraits
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kStoreLSE = kStoreLSE_;