mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
[CK_TILE] Add logits soft-capping & customization support to the FMHA forward kernel/pipelines (#2163)
* hack for cap logits * fix bug * Re-format files * Allow specifying logits_soft_cap through APIs * Support turn on/off logits_soft_cap in async pipeline * Do not generate non-verified kernels * Align receipt used in Aiter * Sync logits soft-capping across pipelines * Re-enable some hdim pipelines * fix perf * Add attention variant for logits_soft_cap * Add newline at end-of-file * Fix performance * Add comment to explain logits_soft_cap pre-processing * Unify code * Unify floating-point literal style * Use class data member to slience the compilation error * [CK_TILE] Update attention customizaton interface: add LogitsMask() (#2133) * Send 'mask' along with variant params to the LogitsMask() * Send block indices to the variant * Add indices parameters in variant interface * Fix fmha bwd codegen error * Allow switch logits_soft_cap impl * Eliminate register spills * Fix compilation errors * Fix wrong LSE * Fix LSE for splitkv kernel * Sync splitkv pipeline changes * Add batch_prefill kernel/pipeline * Fix codegen error * Undo changes in CMakeLists.txt * Merge pipeline filtering check * Use different code path if kHasLogitsSoftCap=false * Remove [[maybe_unused]] attribute * Use pre-existing compile-time flag to instantiate templates * Sync pipeline changes * Update CHANGELOG.md --------- Co-authored-by: Bernard <bernaliu@amd.com> Co-authored-by: coderfeli <coderfeli@163.com>
This commit is contained in:
595
example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
Normal file
595
example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
Normal file
@@ -0,0 +1,595 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
|
||||
|
||||
DTYPE_BITS = {
|
||||
"fp32": 32,
|
||||
"fp16": 16,
|
||||
"bf16": 16,
|
||||
"fp8" : 8,
|
||||
"bf8" : 8
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
32 : 32,
|
||||
64 : 64,
|
||||
96 : 128,
|
||||
128: 128,
|
||||
256: 256
|
||||
}
|
||||
|
||||
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
|
||||
"qr_async" : "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
|
||||
}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_logits},
|
||||
{F_bias},
|
||||
false,
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
{F_occupancy}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_batch_prefill_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp"
|
||||
FMHA_FWD_API="""
|
||||
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag : str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
dropout : str
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.spad == 't' : return 'true' # always support
|
||||
else : return 'true'
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_q % {self.bm0} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
|
||||
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
|
||||
elif self.pipeline_tag in ['qr', 'qr_fp8']:
|
||||
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_k % {self.bn0} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {bk0submax} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {bk0submax} == 0'
|
||||
else: assert False
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdPipeline:
|
||||
tag : str
|
||||
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_skpad == 't' : n += 'sk'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f'{self.tag}_v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_logits == 't' : n += '_logits'
|
||||
else: n += '_nlogits'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else: n += '_nmask'
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
else: n += '_nmask'
|
||||
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
else: n += '_nlse'
|
||||
|
||||
if self.F_dropout == 't' : n += '_dropout'
|
||||
else: n += '_ndropout'
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
return n
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
if trait.hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][trait.hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][hdim]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0 : int # gemm0 warp size along m
|
||||
F_wn0 : int # gemm0 warp size along n
|
||||
F_wk0 : int # gemm0 warp size along k
|
||||
F_wm1 : int # gemm1 warp size along m
|
||||
F_wn1 : int # gemm1 warp size along n
|
||||
F_wk1 : int # gemm1 warp size along k
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
|
||||
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_tile : FmhaFwdTileSize
|
||||
F_pipeline : FmhaFwdPipeline
|
||||
mask_impl : str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
return FMHA_FWD_KERNEL_HEADER + \
|
||||
FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk0max = self.F_tile.F_bk0max,
|
||||
F_rm0 = self.F_tile.F_rm0,
|
||||
F_rn0 = self.F_tile.F_rn0,
|
||||
F_rk0 = self.F_tile.F_rk0,
|
||||
F_rm1 = self.F_tile.F_rm1,
|
||||
F_rn1 = self.F_tile.F_rn1,
|
||||
F_rk1 = self.F_tile.F_rk1,
|
||||
F_wm0 = self.F_tile.F_wm0,
|
||||
F_wn0 = self.F_tile.F_wn0,
|
||||
F_wk0 = self.F_tile.F_wk0,
|
||||
F_wm1 = self.F_tile.F_wm1,
|
||||
F_wn1 = self.F_tile.F_wn1,
|
||||
F_wk1 = self.F_tile.F_wk1,
|
||||
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_pipeline = FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
dropout=self.F_pipeline.F_dropout,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
### '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
### '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
continue
|
||||
k = FmhaFwdKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl)
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'batch'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_batch_prefill) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_batch_prefill C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
with file_path.open('a') as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
|
||||
@@ -60,6 +60,7 @@ using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
false,
|
||||
{F_bias},
|
||||
{F_dbias},
|
||||
false,
|
||||
|
||||
@@ -32,6 +32,7 @@ K0_MAX_SUBMAX_MAP = {
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
@@ -51,12 +52,16 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_logits},
|
||||
{F_bias},
|
||||
false,
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
{F_occupancy}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
@@ -73,6 +78,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
@@ -88,7 +94,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -123,9 +129,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
return fmha_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -144,6 +150,7 @@ class FmhaFwdApiTrait:
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
@@ -157,7 +164,7 @@ class FmhaFwdApiTrait:
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
@@ -165,7 +172,7 @@ class FmhaFwdApiTrait:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.spad == 't' : return 'true' # always support
|
||||
else : return 'true'
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
elif self.pipeline_tag in ['qr', 'qs']:
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_q % {self.bm0} == 0'
|
||||
else: assert False
|
||||
@@ -176,7 +183,7 @@ class FmhaFwdApiTrait:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
|
||||
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
|
||||
elif self.pipeline_tag in ['qr', 'qr_fp8']:
|
||||
elif self.pipeline_tag in ['qr', 'qs']:
|
||||
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_k % {self.bn0} == 0'
|
||||
else: assert False
|
||||
@@ -187,7 +194,7 @@ class FmhaFwdApiTrait:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
elif self.pipeline_tag in ['qr', 'qs']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {bk0submax} == 0'
|
||||
@@ -199,7 +206,7 @@ class FmhaFwdApiTrait:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
elif self.pipeline_tag in ['qr', 'qs']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {bk0submax} == 0'
|
||||
@@ -214,6 +221,7 @@ class FmhaFwdPipeline:
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
@@ -235,6 +243,9 @@ class FmhaFwdPipeline:
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_logits == 't' : n += '_logits'
|
||||
else: n += '_nlogits'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
@@ -280,7 +291,7 @@ class FmhaFwdApiPool:
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
@@ -365,6 +376,7 @@ class FmhaFwdKernel:
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
@@ -399,6 +411,7 @@ class FmhaFwdKernel:
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
dropout=self.F_pipeline.F_dropout,
|
||||
@@ -440,36 +453,36 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
@@ -497,6 +510,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
continue
|
||||
k = FmhaFwdKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
|
||||
@@ -45,6 +45,7 @@ FMHA_FWD_SPLITKV_PIPELINE_MAP = {
|
||||
|
||||
FMHA_FWD_SPLITKV_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
namespace {{
|
||||
@@ -63,6 +64,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_logits},
|
||||
{F_bias},
|
||||
/*kHasBiasGrad=*/false,
|
||||
{F_lse},
|
||||
@@ -85,6 +87,7 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
fmha_shape,
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_trait>;
|
||||
|
||||
@@ -113,7 +116,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
}}
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
|
||||
{F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
@@ -267,9 +270,9 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
// get combine kernel tile sizes
|
||||
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
|
||||
@@ -310,6 +313,7 @@ class FmhaFwdSplitKVApiTrait:
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
mask : str
|
||||
logits : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
squant : str #
|
||||
@@ -322,7 +326,7 @@ class FmhaFwdSplitKVApiTrait:
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
|
||||
f'{self.dvpad}-{self.pagedkv}'
|
||||
|
||||
@property
|
||||
@@ -380,6 +384,7 @@ class FmhaFwdSplitKVPipeline:
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_squant : str #
|
||||
@@ -401,6 +406,9 @@ class FmhaFwdSplitKVPipeline:
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_logits == 't' : n += '_logits'
|
||||
else: n += '_nlogits'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
@@ -475,7 +483,7 @@ class FmhaFwdSplitKVApiPool:
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
@@ -541,6 +549,7 @@ class FmhaFwdSplitKVKernel:
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
@@ -574,6 +583,7 @@ class FmhaFwdSplitKVKernel:
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
@@ -671,32 +681,32 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
|
||||
for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
|
||||
# TODO: use async pipeline when compiler is more stable
|
||||
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
|
||||
# if True:
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
|
||||
else:
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
|
||||
if receipt == 1:
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask))
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 't', squant, 'f', mask))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
@@ -720,6 +730,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
continue
|
||||
k = Kernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
|
||||
Reference in New Issue
Block a user