[CK_TILE] FMHA BWD Decode Pipeline (#2643)

* Fix distr

* Duplicate block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr

* decode 16x16 o2
This commit is contained in:
Yi DING
2025-08-12 17:02:52 +08:00
committed by GitHub
parent 352f87e684
commit 8e1eb0c1ee
11 changed files with 1051 additions and 165 deletions

View File

@@ -127,5 +127,7 @@ PIPELINE_ENUM_MAP = {
BOOL_MAP = {
"t" : "true",
"f" : "false"
"f" : "false",
True : "true",
False : "false",
}

View File

@@ -7,7 +7,7 @@ from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple, Dict, Literal
from typing import List, Tuple, Dict, Literal, Any
from collections import defaultdict
from codegen.cmake_config import *
@@ -31,6 +31,7 @@ using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>;
using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>;
using fmha_warp_tile2_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, ck_tile::min({F_wk0}, {F_bk4})>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
@@ -46,7 +47,8 @@ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx}
fmha_block_warps1_{F_idx},
fmha_warp_tile1_{F_idx},
fmha_block_warps2_{F_idx},
fmha_warp_tile0_{F_idx}>;
fmha_warp_tile2_{F_idx},
{F_maxq}>;
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<false, /* kPadSeqLenQ */
false, /* kPadSeqLenK */
@@ -100,10 +102,17 @@ using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
false,
{F_dvpad}>>;
using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType,
false,
{F_dpad}>>;
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
fmha_bwd_dk_epilogue_{F_idx},
fmha_bwd_dv_epilogue_{F_idx}>;
fmha_bwd_dv_epilogue_{F_idx},
fmha_bwd_dq_epilogue_{F_idx}>;
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
{F_dtype},
@@ -115,7 +124,8 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
{F_dpad},
{F_dvpad},
{F_deterministic},
{F_trload}>;
{F_trload},
{F_maxq}>;
#include <iostream>
@@ -144,6 +154,13 @@ void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_co
ck_tile::stream_config{{s.stream_id_}});
}}
template <>
int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}>()
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::kMaxSeqLenQ;
}}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
{{
@@ -159,13 +176,25 @@ FMHA_BWD_API="""
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_>
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
);
if constexpr (!std::is_same_v<convert_dq_trait_, void>)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
);
}}
else
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
);
}}
}}
template <>
@@ -177,28 +206,25 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf
}}
"""
FMHA_BWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{
{F_body}
}}
"""
def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0) -> str:
lines = [
f"{'if' if if_ == 0 else 'else if'}({F_cond})",
"{",
*[' ' + line for line in F_body.split('\n') if line.strip() != ''],
"}",
]
return '\n'.join(' ' * indent + line for line in lines) + '\n'
FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_body}
}}
"""
FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
{F_body}
}}
"""
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
return r;
}}
FMHA_BWD_API_INNER_DISPATCH="""
{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>>(s, a);
return r;
}}
"""
# M0 size for 1d kernels (dot/convert)
@@ -237,11 +263,13 @@ class FmhaBwdDQDKDVTileSize:
F_wn1 : int # warp size along n in gemm1/gemm3
F_wk1 : int # warp size along k in gemm1/gemm3
F_occupancy : int # occupancy
max_seq_q : int = 0
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}"
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}_maxq{self.max_seq_q}"
@dataclass(frozen=True)
class FmhaBwdDQDKDVKernel:
@@ -301,6 +329,7 @@ class FmhaBwdDQDKDVKernel:
F_mode = MODE_MAP[self.F_mode],
F_deterministic = BOOL_MAP[self.F_deterministic],
F_trload = BOOL_MAP[self.F_trload],
F_maxq = self.F_tile.max_seq_q
)
@property
@@ -345,21 +374,23 @@ class FmhaBwdDQDKDVKernel:
# TODO: design a more practical way to do it
# this is current supported tile size.
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str, tr_load: str) -> Optional[dict]:
def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f':
return {
'32' : FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
'64' : FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
'128' : FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# '160' : FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
'256' : FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
}
return [
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
]
elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't':
return {
'128' : FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
}
return [
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
# FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16),
FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16),
]
else:
return None
return []
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
@@ -537,6 +568,7 @@ class FmhaBwdConvertQGradKernel:
F_mode : str # value from MODE_MAP
F_occupancy : int #
F_deterministic : str #
disabled : bool # sometimes this kernel is not used
@property
def template(self) -> str:
@@ -590,7 +622,7 @@ class FmhaBwdApiTrait:
dvpad : str
deterministic : str
mask_impl : str
tr_load : bool
tr_load : str
@property
def bm0(self) -> int:
@@ -650,17 +682,17 @@ class FmhaBwdApiTrait:
return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_spad=self.spad1d, F_dpad=self.dpad,
F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim),
F_deterministic=self.deterministic)
F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0)
class FmhaBwdApiPool:
def __init__(self, mask_impl):
self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
self.mask_impl = mask_impl
def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None:
# TODO: do we need to check duplication?
self.dq_dk_dv_pool[trait.tr_load][trait.dtype][trait.hdim].append(copy.copy(trait))
self.dq_dk_dv_pool[trait.tr_load][trait.tile.max_seq_q][trait.dtype][trait.hdim].append(copy.copy(trait))
@staticmethod
def if_(i: int) -> str:
@@ -675,40 +707,68 @@ class FmhaBwdApiPool:
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load])
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q,
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled])
i += 1
return inners
@staticmethod
def trload_sort_key(tf):
return 0 if tf == 't' else 1 # sort 't' before 'f'
@staticmethod
def max_seq_q_sort_key(max_seq_q):
return max_seq_q if max_seq_q != 0 else 1000000 # sort 0 to the end
@staticmethod
def max_seq_q_cond(max_seq_q: int) -> str:
if max_seq_q == 0:
return 'true /* no seqlen_q limit */'
else:
return f'a.seqlen_q <= {max_seq_q}'
@staticmethod
def dtype_cond(dtype: str) -> str:
return f't.data_type.compare("{dtype}") == 0'
@staticmethod
def hdim_cond(hdim: int) -> str:
return f't.hdim_q <= {hdim} && t.hdim_v <= {hdim}'
@property
def api(self) -> str:
tr_load_cond_map = {
"t": "has_load_tr",
"f": "true"
"f": "true /* no trload requirement */"
}
per_tr_load = ''
for tr_load in ["t", "f"]:
per_dtypes = ''
for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load]):
per_hdim_case = ''
for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][dtype]):
traits = self.dq_dk_dv_pool[tr_load][dtype][hdim]
inners = self._api_innders(traits)
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=self.if_(k), F_hdim=hdim, F_body=inners)
per_dtypes += FMHA_BWD_API_PER_DTYPE.format(F_if=self.if_(j), F_dtype=dtype, F_body=per_hdim_case)
per_tr_load += FMHA_BWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_body=per_dtypes)
for tr_load in sorted(self.dq_dk_dv_pool.keys(), key=self.trload_sort_key):
per_max_seq_q = ''
for max_seq_q in sorted(self.dq_dk_dv_pool[tr_load].keys(), key=self.max_seq_q_sort_key):
per_dtypes = ''
for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q]):
per_hdim_case = ''
for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q][dtype]):
traits = self.dq_dk_dv_pool[tr_load][max_seq_q][dtype][hdim]
inners = self._api_innders(traits)
per_hdim_case += FMHA_BWD_API_COND_STATEMENT(if_=k, F_cond=self.hdim_cond(hdim), F_body=inners)
per_dtypes += FMHA_BWD_API_COND_STATEMENT(if_=j, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case)
per_max_seq_q += FMHA_BWD_API_COND_STATEMENT(F_cond=self.max_seq_q_cond(max_seq_q), F_body=per_dtypes)
per_tr_load += FMHA_BWD_API_COND_STATEMENT(F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4)
if not per_tr_load:
# empty string we add some ignore to suppress warning in api
per_tr_load += ' (void)t ; (void)s ; (void)a;'
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load)
result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load)
return result.replace('\n\n', '\n')
def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]:
if filter_list == '':
filter_list = '*@*@*'
filter_list = filter_list.split('@')
filter_list.extend(['*'] * (3 - len(filter_list)))
filter_dot_do_o = filter_list[0]
filter_convert_dq = filter_list[1]
filter_dq_dk_dv = filter_list[2]
filters = filter_list.split('@')
filters.extend(['*'] * (3 - len(filters)))
filter_dot_do_o = filters[0]
filter_convert_dq = filters[1]
filter_dq_dk_dv = filters[2]
# use dict as ordered set
gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {}
@@ -717,14 +777,14 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
api_pool = FmhaBwdApiPool(mask_impl)
for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]):
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype, tr_load)
if d is None:
continue
for hdim_str, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)):
tile = d[hdim_str]
hdim = int(hdim_str)
tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load)
for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)):
assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize"
hdim = tile.F_bhdq
if (mode == "group") and (spad1d == "f"):
continue
if (mode == "group" or ('no' not in mask)) and tile.max_seq_q != 0:
continue
if ((bias == "no" or bias == "alibi") and dbias == "t"):
continue
if ("wg32" in dropout):
@@ -788,7 +848,8 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
continue
gen_dot_do_o[t.dot_do_o_kernel] = True
gen_dq_dk_dv[t.dq_dk_dv_kernel] = True
gen_convert_dq[t.convert_dq_kernel] = True
if not t.convert_dq_kernel.disabled:
gen_convert_dq[t.convert_dq_kernel] = True
api_pool.register_dq_dk_dv_traits(t)
return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys())

View File

@@ -793,6 +793,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
// set to bad values to check if the kernel writes to these buffers
ck_tile::FillConstant<QGradDataType>{ck_tile::numeric<QGradDataType>::infinity()}(dq_host);
ck_tile::FillConstant<KGradDataType>{ck_tile::numeric<KGradDataType>::infinity()}(dk_host);
ck_tile::FillConstant<VGradDataType>{ck_tile::numeric<VGradDataType>::infinity()}(dv_host);
dq_buf.ToDevice(dq_host.data());
dk_buf.ToDevice(dk_host.data());
dv_buf.ToDevice(dv_host.data());
o_buf.ToDevice(o_host.data());
lse_buf.ToDevice(lse_host.data());
dq_buf.SetZero();
@@ -801,6 +809,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::stream_config stream_config_v{
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
printf("\nfmha_bwd_traits: hdim_q=%d, hdim_v=%d, data_type=%s, is_group_mode=%d, mask_type=%d, "
"bias_type=%d, has_dbias=%d, has_dropout=%d, is_store_randval=%d, is_deterministic=%d\n",
fmha_traits.hdim_q,
fmha_traits.hdim_v,
fmha_traits.data_type.c_str(),
fmha_traits.is_group_mode,
static_cast<int>(fmha_traits.mask_type),
static_cast<int>(fmha_traits.bias_type),
fmha_traits.has_dbias,
fmha_traits.has_dropout,
fmha_traits.is_store_randval,
fmha_traits.is_deterministic);
fflush(stdout);
fmha_bwd(fmha_traits, fmha_args, stream_config_v);
dq_buf.FromDevice(dq_host.data());

View File

@@ -156,6 +156,12 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
constexpr bool dq_uss_acc = FmhaBwdDQDKDVKernel::kMaxSeqLenQ == 0;
const auto dq_ptr = dq_uss_acc ? args.dq_acc_ptr : args.dq_ptr;
const auto stride_dq = dq_uss_acc ? args.stride_dq_acc : args.stride_dq;
const auto nhead_stride_dq = dq_uss_acc ? args.nhead_stride_dq_acc : args.nhead_stride_dq;
const auto batch_stride_dq = dq_uss_acc ? args.batch_stride_dq_acc : args.batch_stride_dq;
// create group mode kernel arguments
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
{
@@ -170,7 +176,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.dq_acc_ptr,
dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
@@ -185,7 +191,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dq_acc,
stride_dq,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
@@ -196,7 +202,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
nhead_stride_dq,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
@@ -220,7 +226,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.dq_acc_ptr,
dq_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
@@ -234,7 +240,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dq_acc,
stride_dq,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
@@ -245,7 +251,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
nhead_stride_dq,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
@@ -256,7 +262,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.batch_stride_randval,
args.batch_stride_do,
args.batch_stride_lsed,
args.batch_stride_dq_acc,
batch_stride_dq,
args.batch_stride_dk,
args.batch_stride_dv,
args.batch_stride_dbias,
@@ -365,20 +371,10 @@ template <ck_tile::index_t HDim_,
bool kPadD_,
bool kPadDv_,
bool kIsDeterministic_,
bool kUseTrLoad_>
bool kUseTrLoad_,
ck_tile::index_t MaxSeqLenQ_>
struct fmha_bwd_dq_dk_dv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
static constexpr bool kUseTrLoad = kUseTrLoad_;
};
template <typename Traits_>
@@ -389,6 +385,8 @@ void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_dq_dk_dv_get_name_();
template <typename Traits_>
int fmha_bwd_dq_dk_dv_maxq_();
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
struct fmha_bwd_dot_do_o_traits_