mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
[CK_TILE] Add logits soft-capping & customization support to the FMHA forward kernel/pipelines (#2163)
* hack for cap logits * fix bug * Re-format files * Allow specifying logits_soft_cap through APIs * Support turn on/off logits_soft_cap in async pipeline * Do not generate non-verified kernels * Align receipt used in Aiter * Sync logits soft-capping across pipelines * Re-enable some hdim pipelines * fix perf * Add attention variant for logits_soft_cap * Add newline at end-of-file * Fix performance * Add comment to explain logits_soft_cap pre-processing * Unify code * Unify floating-point literal style * Use class data member to slience the compilation error * [CK_TILE] Update attention customizaton interface: add LogitsMask() (#2133) * Send 'mask' along with variant params to the LogitsMask() * Send block indices to the variant * Add indices parameters in variant interface * Fix fmha bwd codegen error * Allow switch logits_soft_cap impl * Eliminate register spills * Fix compilation errors * Fix wrong LSE * Fix LSE for splitkv kernel * Sync splitkv pipeline changes * Add batch_prefill kernel/pipeline * Fix codegen error * Undo changes in CMakeLists.txt * Merge pipeline filtering check * Use different code path if kHasLogitsSoftCap=false * Remove [[maybe_unused]] attribute * Use pre-existing compile-time flag to instantiate templates * Sync pipeline changes * Update CHANGELOG.md --------- Co-authored-by: Bernard <bernaliu@amd.com> Co-authored-by: coderfeli <coderfeli@163.com>
This commit is contained in:
@@ -114,12 +114,14 @@ LAYOUT_MAP = {
|
||||
PIPELINE_MAP = {
|
||||
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
|
||||
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
|
||||
"qs" : "ck_tile::BlockFmhaPipelineQSKSVS",
|
||||
}
|
||||
|
||||
PIPELINE_ENUM_MAP = {
|
||||
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
|
||||
}
|
||||
|
||||
BOOL_MAP = {
|
||||
|
||||
595
example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
Normal file
595
example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
Normal file
@@ -0,0 +1,595 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
|
||||
|
||||
DTYPE_BITS = {
|
||||
"fp32": 32,
|
||||
"fp16": 16,
|
||||
"bf16": 16,
|
||||
"fp8" : 8,
|
||||
"bf8" : 8
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
32 : 32,
|
||||
64 : 64,
|
||||
96 : 128,
|
||||
128: 128,
|
||||
256: 256
|
||||
}
|
||||
|
||||
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
|
||||
"qr_async" : "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
|
||||
}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_logits},
|
||||
{F_bias},
|
||||
false,
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
{F_occupancy}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_batch_prefill_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp"
|
||||
FMHA_FWD_API="""
|
||||
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag : str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
dropout : str
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.spad == 't' : return 'true' # always support
|
||||
else : return 'true'
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_q % {self.bm0} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
|
||||
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
|
||||
elif self.pipeline_tag in ['qr', 'qr_fp8']:
|
||||
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_k % {self.bn0} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {bk0submax} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {bk0submax} == 0'
|
||||
else: assert False
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdPipeline:
|
||||
tag : str
|
||||
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_skpad == 't' : n += 'sk'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f'{self.tag}_v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_logits == 't' : n += '_logits'
|
||||
else: n += '_nlogits'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else: n += '_nmask'
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
else: n += '_nmask'
|
||||
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
else: n += '_nlse'
|
||||
|
||||
if self.F_dropout == 't' : n += '_dropout'
|
||||
else: n += '_ndropout'
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
return n
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
if trait.hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][trait.hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][hdim]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0 : int # gemm0 warp size along m
|
||||
F_wn0 : int # gemm0 warp size along n
|
||||
F_wk0 : int # gemm0 warp size along k
|
||||
F_wm1 : int # gemm1 warp size along m
|
||||
F_wn1 : int # gemm1 warp size along n
|
||||
F_wk1 : int # gemm1 warp size along k
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
|
||||
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_tile : FmhaFwdTileSize
|
||||
F_pipeline : FmhaFwdPipeline
|
||||
mask_impl : str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
return FMHA_FWD_KERNEL_HEADER + \
|
||||
FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk0max = self.F_tile.F_bk0max,
|
||||
F_rm0 = self.F_tile.F_rm0,
|
||||
F_rn0 = self.F_tile.F_rn0,
|
||||
F_rk0 = self.F_tile.F_rk0,
|
||||
F_rm1 = self.F_tile.F_rm1,
|
||||
F_rn1 = self.F_tile.F_rn1,
|
||||
F_rk1 = self.F_tile.F_rk1,
|
||||
F_wm0 = self.F_tile.F_wm0,
|
||||
F_wn0 = self.F_tile.F_wn0,
|
||||
F_wk0 = self.F_tile.F_wk0,
|
||||
F_wm1 = self.F_tile.F_wm1,
|
||||
F_wn1 = self.F_tile.F_wn1,
|
||||
F_wk1 = self.F_tile.F_wk1,
|
||||
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_pipeline = FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
dropout=self.F_pipeline.F_dropout,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
### '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
### '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
continue
|
||||
k = FmhaFwdKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl)
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'batch'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_batch_prefill) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_batch_prefill C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
with file_path.open('a') as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
|
||||
@@ -60,6 +60,7 @@ using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
false,
|
||||
{F_bias},
|
||||
{F_dbias},
|
||||
false,
|
||||
|
||||
@@ -32,6 +32,7 @@ K0_MAX_SUBMAX_MAP = {
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
@@ -51,12 +52,16 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_logits},
|
||||
{F_bias},
|
||||
false,
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
{F_occupancy}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
@@ -73,6 +78,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
@@ -88,7 +94,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -123,9 +129,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
return fmha_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -144,6 +150,7 @@ class FmhaFwdApiTrait:
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
@@ -157,7 +164,7 @@ class FmhaFwdApiTrait:
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
@@ -165,7 +172,7 @@ class FmhaFwdApiTrait:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.spad == 't' : return 'true' # always support
|
||||
else : return 'true'
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
elif self.pipeline_tag in ['qr', 'qs']:
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_q % {self.bm0} == 0'
|
||||
else: assert False
|
||||
@@ -176,7 +183,7 @@ class FmhaFwdApiTrait:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
|
||||
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
|
||||
elif self.pipeline_tag in ['qr', 'qr_fp8']:
|
||||
elif self.pipeline_tag in ['qr', 'qs']:
|
||||
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_k % {self.bn0} == 0'
|
||||
else: assert False
|
||||
@@ -187,7 +194,7 @@ class FmhaFwdApiTrait:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
elif self.pipeline_tag in ['qr', 'qs']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {bk0submax} == 0'
|
||||
@@ -199,7 +206,7 @@ class FmhaFwdApiTrait:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
elif self.pipeline_tag in ['qr', 'qs']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {bk0submax} == 0'
|
||||
@@ -214,6 +221,7 @@ class FmhaFwdPipeline:
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
@@ -235,6 +243,9 @@ class FmhaFwdPipeline:
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_logits == 't' : n += '_logits'
|
||||
else: n += '_nlogits'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
@@ -280,7 +291,7 @@ class FmhaFwdApiPool:
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
@@ -365,6 +376,7 @@ class FmhaFwdKernel:
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
@@ -399,6 +411,7 @@ class FmhaFwdKernel:
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
dropout=self.F_pipeline.F_dropout,
|
||||
@@ -440,36 +453,36 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
@@ -497,6 +510,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
continue
|
||||
k = FmhaFwdKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
|
||||
@@ -45,6 +45,7 @@ FMHA_FWD_SPLITKV_PIPELINE_MAP = {
|
||||
|
||||
FMHA_FWD_SPLITKV_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
namespace {{
|
||||
@@ -63,6 +64,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_logits},
|
||||
{F_bias},
|
||||
/*kHasBiasGrad=*/false,
|
||||
{F_lse},
|
||||
@@ -85,6 +87,7 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
fmha_shape,
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_trait>;
|
||||
|
||||
@@ -113,7 +116,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
}}
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
|
||||
{F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
@@ -267,9 +270,9 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
// get combine kernel tile sizes
|
||||
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
|
||||
@@ -310,6 +313,7 @@ class FmhaFwdSplitKVApiTrait:
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
mask : str
|
||||
logits : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
squant : str #
|
||||
@@ -322,7 +326,7 @@ class FmhaFwdSplitKVApiTrait:
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
|
||||
f'{self.dvpad}-{self.pagedkv}'
|
||||
|
||||
@property
|
||||
@@ -380,6 +384,7 @@ class FmhaFwdSplitKVPipeline:
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_squant : str #
|
||||
@@ -401,6 +406,9 @@ class FmhaFwdSplitKVPipeline:
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_logits == 't' : n += '_logits'
|
||||
else: n += '_nlogits'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
@@ -475,7 +483,7 @@ class FmhaFwdSplitKVApiPool:
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
@@ -541,6 +549,7 @@ class FmhaFwdSplitKVKernel:
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
@@ -574,6 +583,7 @@ class FmhaFwdSplitKVKernel:
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
@@ -671,32 +681,32 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
|
||||
for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
|
||||
# TODO: use async pipeline when compiler is more stable
|
||||
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
|
||||
# if True:
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
|
||||
else:
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
|
||||
if receipt == 1:
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask))
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 't', squant, 'f', mask))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
@@ -720,6 +730,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
continue
|
||||
k = Kernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
@@ -72,6 +73,7 @@ auto create_args(int argc, char* argv[])
|
||||
"0",
|
||||
"scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
|
||||
"note when squant=1, this value will be modified by range_q/k")
|
||||
.insert("logits_soft_cap", "0", "attention logits soft capping value.")
|
||||
.insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.")
|
||||
.insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.")
|
||||
.insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.")
|
||||
@@ -416,6 +418,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(scale_s == .0f)
|
||||
scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ?
|
||||
|
||||
const float logits_soft_cap = arg_parser.get_float("logits_soft_cap");
|
||||
|
||||
std::string squant_str = arg_parser.get_str("squant");
|
||||
bool squant = [&]() {
|
||||
if(squant_str == "auto")
|
||||
@@ -850,6 +854,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else // fmha_fwd_traits or fmha_splitkv_traits
|
||||
{
|
||||
traits.is_group_mode = (mode == mode_enum::group);
|
||||
traits.has_logits_soft_cap = 0.f < logits_soft_cap;
|
||||
traits.mask_type = mask.type;
|
||||
traits.bias_type = bias.type;
|
||||
traits.has_lse = lse;
|
||||
@@ -1007,6 +1012,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.scale_p = scale_p;
|
||||
args.scale_o = scale_o;
|
||||
|
||||
args.logits_soft_cap = logits_soft_cap;
|
||||
|
||||
args.stride_bias =
|
||||
(bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias);
|
||||
args.stride_o = stride_o;
|
||||
@@ -1375,6 +1382,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale_s));
|
||||
|
||||
if(0.f < logits_soft_cap)
|
||||
{
|
||||
ck_tile::reference_unary_elementwise<SaccDataType, SaccDataType, SaccDataType>(
|
||||
s_host_ref, s_host_ref, [logits_soft_cap](SaccDataType logits) {
|
||||
return ck_tile::type_convert<SaccDataType>(
|
||||
logits_soft_cap *
|
||||
std::tanhf(ck_tile::type_convert<float>(logits / logits_soft_cap)));
|
||||
});
|
||||
}
|
||||
|
||||
if(bias.type == bias_enum::elementwise_bias)
|
||||
{
|
||||
// elementwise bias
|
||||
|
||||
@@ -143,6 +143,8 @@ struct fmha_fwd_args
|
||||
float scale_p;
|
||||
float scale_o;
|
||||
|
||||
float logits_soft_cap;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
@@ -232,6 +234,8 @@ struct fmha_fwd_splitkv_args
|
||||
float scale_p;
|
||||
float scale_o;
|
||||
|
||||
float logits_soft_cap;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
@@ -308,6 +312,85 @@ struct fmha_fwd_appendkv_args
|
||||
ck_tile::index_t batch_stride_vnew;
|
||||
};
|
||||
|
||||
struct fmha_batch_prefill_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
void* rand_val_ptr;
|
||||
void* lse_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
// the real seqlen_q & seqlen_k are decided by following:
|
||||
// batch mode (kvcache):
|
||||
// seqlen_q = kargs.seqlen_q
|
||||
// seqlen_k = kargs.page_block_size * (kargs.kv_indptr[b + 1] - kargs.kv_indptr[b] -
|
||||
// 1) +
|
||||
// kargs.kv_last_page_lens[b]
|
||||
// group mode (kvcache):
|
||||
// seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
|
||||
// seqlen_k = kargs.page_block_size * (kargs.kv_indptr[b + 1] - kargs.kv_indptr[b] -
|
||||
// 1) +
|
||||
// kargs.kv_last_page_lens[b]
|
||||
const void* seqstart_q_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
// SGLang-style page table
|
||||
int32_t num_total_pages;
|
||||
void* kv_indptr;
|
||||
void* kv_page_indices;
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
void* kv_last_page_lens;
|
||||
ck_tile::index_t page_block_size;
|
||||
#endif
|
||||
|
||||
float scale_s;
|
||||
float scale_p;
|
||||
float scale_o;
|
||||
|
||||
float logits_soft_cap;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
|
||||
float p_drop;
|
||||
bool s_randval;
|
||||
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
};
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
{
|
||||
@@ -333,6 +416,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
@@ -371,6 +455,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
@@ -443,6 +528,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.is_gappy,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
@@ -485,6 +571,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.cache_batch_idx,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
@@ -618,6 +705,117 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
args.kv_last_page_lens,
|
||||
args.page_block_size,
|
||||
#endif
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
args.kv_last_page_lens,
|
||||
args.page_block_size,
|
||||
#endif
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
@@ -630,6 +828,7 @@ template <ck_tile::index_t HDim_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
|
||||
bool kHasLogitsSoftCap_,
|
||||
typename FmhaMask_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kStoreLse_,
|
||||
@@ -652,6 +851,7 @@ struct fmha_fwd_traits_
|
||||
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
|
||||
static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kStoreLse = kStoreLse_;
|
||||
@@ -677,6 +877,7 @@ template <ck_tile::index_t HDim_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
|
||||
bool kHasLogitsSoftCap_,
|
||||
typename FmhaMask_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kStoreLse_,
|
||||
@@ -699,6 +900,7 @@ struct fmha_fwd_splitkv_traits_
|
||||
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
|
||||
static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kStoreLse = kStoreLse_;
|
||||
@@ -776,6 +978,9 @@ struct fmha_fwd_appendkv_traits_
|
||||
template <typename Traits_>
|
||||
float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args);
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_batch_prefill_(const ck_tile::stream_config&, fmha_batch_prefill_args);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct fmha_fwd_traits
|
||||
{
|
||||
@@ -784,6 +989,7 @@ struct fmha_fwd_traits
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
bool is_v_rowmajor;
|
||||
bool has_logits_soft_cap;
|
||||
mask_enum mask_type;
|
||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||
bool has_lse;
|
||||
@@ -800,6 +1006,7 @@ struct fmha_fwd_splitkv_traits
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
bool is_v_rowmajor;
|
||||
bool has_logits_soft_cap;
|
||||
mask_enum mask_type;
|
||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||
bool has_lse;
|
||||
@@ -821,3 +1028,8 @@ struct fmha_fwd_appendkv_traits
|
||||
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
|
||||
fmha_fwd_appendkv_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
using fmha_batch_prefill_traits = fmha_fwd_traits;
|
||||
float fmha_batch_prefill(fmha_batch_prefill_traits,
|
||||
fmha_batch_prefill_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
@@ -21,8 +21,7 @@ class HandlerId(IntEnum):
|
||||
ops = []
|
||||
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
|
||||
full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
|
||||
if full_module_name not in sys.modules:
|
||||
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
|
||||
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
|
||||
unwanted_prefix = 'fmha_'
|
||||
handlers = dict(
|
||||
[(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__,
|
||||
|
||||
Reference in New Issue
Block a user