refactor fmha_bwd.py (#2546)

This commit is contained in:
Illia Silin
2025-07-23 09:09:56 -07:00
committed by GitHub
parent a5fdc663c8
commit 1b6f024836

View File

@@ -7,7 +7,7 @@ from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Dict, Literal
from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
@@ -204,107 +204,13 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode})
}}
"""
@dataclass
class FmhaBwdDQDKDVApiTrait:
pipeline : str
# sync with fmha_bwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along k seqlen
bhdq : int # q head_dim
bhdv : int # v head_dim
mask : str
bias : str
dbias : str
dropout : str
spad : str
skpad : str
dpad : str
dvpad : str
deterministic : str
def scheck(self, spad1 : str) -> str:
if self.mode == 'group':
return 'true' # always support
elif self.spad == 't' and spad1 == 't':
return f'a.seqlen_q % {self.bm0} != 0'
elif self.spad == 'f' and spad1 == 't':
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0'
else: # self.skpad == 'f' and skpad1 == 'f'
return f'a.seqlen_q % 64 == 0'
@property
def skcheck(self) -> str:
if self.mode == 'group':
return 'true' # always support
elif self.skpad == 't':
return f'a.seqlen_k % {self.bn0} != 0'
else:
return f'a.seqlen_k % {self.bn0} == 0'
@property
def dcheck(self) -> str:
if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0'
else : return f'a.hdim_q % {self.bhdq} == 0'
@property
def dvcheck(self) -> str:
if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0'
else : return f'a.hdim_v % {self.bhdv} == 0'
class FmhaBwdApiPool:
def __init__(self, mask_impl):
self.dq_dk_dv_pool = dict()
self.mask_impl = mask_impl
def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.dq_dk_dv_pool.keys():
self.dq_dk_dv_pool[trait.dtype] = dict()
if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys():
self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list()
self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes=str()
for i, dtype in enumerate(self.dq_dk_dv_pool.keys()):
per_hdim_case=str()
for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()):
traits=self.dq_dk_dv_pool[dtype][hdim]
hdim_int = int(hdim)
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
for spad1 in ["t", "f"]:
if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")):
continue
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
# GEMM0: Q@K=S^T
# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order)
# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk)
# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk)
# Is it necessary to distinguish between K0~K4?
@dataclass
@dataclass(frozen=True)
class FmhaBwdDQDKDVTileSize:
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
@@ -337,7 +243,7 @@ class FmhaBwdDQDKDVTileSize:
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}"
@dataclass
@dataclass(frozen=True)
class FmhaBwdDQDKDVKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
@@ -440,26 +346,6 @@ class FmhaBwdDQDKDVKernel:
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaBwdDQDKDVApiTrait:
return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bhdq=self.F_tile.F_bhdq,
bhdv=self.F_tile.F_bhdv,
mask=self.F_mask,
bias=self.F_bias,
dbias=self.F_dbias,
dropout=self.F_dropout,
spad=self.F_spad,
skpad=self.F_skpad,
dpad=self.F_dpad,
dvpad=self.F_dvpad,
deterministic=self.F_deterministic
)
# TODO: design a more practical way to do it
# this is current supported tile size & pipeline.
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
@@ -477,84 +363,6 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict
else:
return None
def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]:
# TODO: we don't support tuning yet, so pick up one value for pad
# support this in future
gen = list()
api_pool = FmhaBwdApiPool(mask_impl)
for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None:
continue
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
tile = d[hdim_str][0]
ppl = d[hdim_str][1]
hdim = int(hdim_str)
if (mode == "group") and (spad == "f" or skpad == "f"):
continue
if ((bias == "no" or bias == "alibi") and dbias == "t"):
continue
if ("wg32" in dropout):
continue
if (dpad == "t" or dvpad == "t"):
ppl = d[hdim_str][2]
k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic)
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
# Flash attention integration
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
if not cond:
continue
elif receipt == 3:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
cond &= dpad == dvpad
cond &= deterministic == "f"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'bias']
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
cond &= mode == 'batch'
cond &= deterministic == "f"
if not cond:
continue
# Aiter (mha_bwd) integration
elif receipt == 300:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
if not cond:
continue
# Aiter (mha_varlen_bwd) integration
elif receipt == 400:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
if not cond:
continue
# aiter::mha_bwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
if not cond:
continue
api_pool.register_dq_dk_dv_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
@@ -613,7 +421,7 @@ std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
}}
"""
@dataclass
@dataclass(frozen=True)
class FmhaBwdOGradDotOKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
@@ -653,49 +461,6 @@ class FmhaBwdOGradDotOKernel:
def filename(self) -> str:
return self.name + ".cpp"
def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdOGradDotOKernel]:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
# support this in future
def get_occupancy(dtype, hdim):
return 2
gen = list()
for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None:
continue
for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]):
hdim = int(hdim_str)
if (mode == "group" and spad == "f"):
continue
k = FmhaBwdOGradDotOKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype,
F_spad=spad, F_dvpad=dvpad, F_mode=mode,
F_occupancy=get_occupancy(dtype, hdim))
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
# Aiter (mha_bwd) integration
if receipt == 300:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"
if not cond:
continue
# Aiter (mha_varlen_bwd) integration
elif receipt == 400:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
if not cond:
continue
# aiter::mha_bwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
if not cond:
continue
gen.append(k)
return gen
FMHA_BWD_CONVERT_DQ_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
@@ -762,7 +527,7 @@ std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}>()
}}
"""
@dataclass
@dataclass(frozen=True)
class FmhaBwdConvertQGradKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
@@ -810,92 +575,257 @@ class FmhaBwdConvertQGradKernel:
def filename(self) -> str:
return self.name + ".cpp"
def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdConvertQGradKernel]:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
# support this in future
def get_occupancy(dtype, hdim):
return 2
@dataclass(frozen=True)
class FmhaBwdApiTrait:
idx : int # this is not a tunable, but a counter to differentiate symbol
pipeline : str
# sync with fmha_bwd_traits<>, to generate fallback calls
hdim : int
dtype : str # data type
mode : str # value from MODE_MAP
tile : FmhaBwdDQDKDVTileSize
mask : str
bias : str
dbias : str
dropout : str
spad : str
spad1 : str # spad for dot/convert kernel
skpad : str
dpad : str
dvpad : str
deterministic : str
mask_impl : str
gen = list()
@property
def bm0(self) -> int:
return self.tile.F_bm0
@property
def bn0(self) -> int:
return self.tile.F_bn0
@property
def bhdq(self) -> int:
return self.tile.F_bhdq
@property
def bhdv(self) -> int:
return self.tile.F_bhdv
def scheck(self, spad1 : str) -> str:
if self.mode == 'group':
return 'true' # always support
elif self.spad == 't' and spad1 == 't':
return f'a.seqlen_q % {self.bm0} != 0'
elif self.spad == 'f' and spad1 == 't':
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0'
else: # self.skpad == 'f' and skpad1 == 'f'
return 'a.seqlen_q % 64 == 0'
@property
def skcheck(self) -> str:
if self.mode == 'group':
return 'true' # always support
elif self.skpad == 't':
return f'a.seqlen_k % {self.bn0} != 0'
else:
return f'a.seqlen_k % {self.bn0} == 0'
@property
def dcheck(self) -> str:
if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0'
else : return f'a.hdim_q % {self.bhdq} == 0'
@property
def dvcheck(self) -> str:
if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0'
else : return f'a.hdim_v % {self.bhdv} == 0'
@property
def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
# support this in future
def get_occupancy(dtype, hdim):
return 2
return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1,
F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim))
@property
def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel:
return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile,
F_spad=self.spad, F_skpad=self.skpad, F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias,
F_dbias=self.dbias, F_dropout=self.dropout, F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, F_pipeline=self.pipeline, mask_impl=self.mask_impl)
@property
def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
# support this in future
def get_occupancy(dtype, hdim):
return 2
return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
F_bm0=64, F_bn0=self.tile.F_bn0, F_spad=self.spad, F_dpad=self.dpad,
F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim),
F_deterministic=self.deterministic)
class FmhaBwdApiPool:
def __init__(self, mask_impl):
self.dq_dk_dv_pool = dict()
self.mask_impl = mask_impl
def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.dq_dk_dv_pool.keys():
self.dq_dk_dv_pool[trait.dtype] = dict()
if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys():
self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list()
self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes=str()
for i, dtype in enumerate(self.dq_dk_dv_pool.keys()):
per_hdim_case=str()
for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()):
traits=self.dq_dk_dv_pool[dtype][hdim]
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
for spad1 in ["t", "f"]:
if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")):
continue
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
def get_bwd_blobs(filter_list: str, receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]:
if filter_list == '':
filter_list = '*@*@*'
filter_list = filter_list.split('@')
filter_list.extend(['*'] * (3 - len(filter_list)))
filter_dot_do_o = filter_list[0]
filter_convert_dq = filter_list[1]
filter_dq_dk_dv = filter_list[2]
# use dict as ordered set
gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {}
gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = {}
gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {}
api_pool = FmhaBwdApiPool(mask_impl)
for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None:
if d is None:
continue
for hdim_str, mode, spad, dpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
hdim = int(hdim_str)
for hdim_str, mode, mask, bias, dbias, dropout, spad, spad1, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 6)):
tile = d[hdim_str][0]
if (mode == "group" and spad == "f"):
ppl = d[hdim_str][1]
hdim = int(hdim_str)
if (mode == "group") and (spad == "f" or skpad == "f"):
continue
k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0,
F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic)
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
if (spad1 == "f") and (spad == "t" or mode == "group"):
continue
if ((bias == "no" or bias == "alibi") and dbias == "t"):
continue
if ("wg32" in dropout):
continue
if (dpad == "t" or dvpad == "t"):
ppl = d[hdim_str][2]
t = FmhaBwdApiTrait(idx=0, pipeline=ppl, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad=spad, spad1=spad1, skpad=skpad, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl)
if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
continue
if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv):
continue
if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq):
continue
# Flash attention integration
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
if not cond:
continue
elif receipt == 3:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
cond &= dpad == dvpad
cond &= deterministic == "f"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'bias']
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
cond &= mode == 'batch'
cond &= deterministic == "f"
if not cond:
continue
# Aiter (mha_bwd) integration
if receipt == 300:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"
if not cond:
continue
elif receipt == 300:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
if not cond:
continue
# Aiter (mha_varlen_bwd) integration
elif receipt == 400:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
if not cond:
continue
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
if not cond:
continue
# aiter::mha_bwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
if not cond:
continue
gen.append(k)
cond = dtype in ['fp16', 'bf16']
if not cond:
continue
gen_dot_do_o[t.dot_do_o_kernel] = True
gen_dq_dk_dv[t.dq_dk_dv_kernel] = True
gen_convert_dq[t.convert_dq_kernel] = True
api_pool.register_dq_dk_dv_traits(t)
return gen
def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys())
def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
filter_list = filter_list.split('@')
filter_list.extend([''] * (3 - len(filter_list)))
# TODO
assert optdim_list == [-1]
assert optdim_list == [-1] # TODO
kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt)
for kernel in kernels:
write_single_bwd_dot_do_o_kernel(kernel, output_dir)
kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt)
for kernel in kernels:
write_single_bwd_convert_dq_kernel(kernel, output_dir)
api_pool, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl)
for kernel in kernels:
write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
write_bwd_api(api_pool, output_dir)
api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl)
(output_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
for k in kernels_dot_do_o:
(output_dir / k.filename).write_text(k.template)
for k in kernels_convert_dq:
(output_dir / k.filename).write_text(k.template)
for k in kernels_dq_dk_dv:
(output_dir / k.filename).write_text(k.template)
def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
filter_list = filter_list.split('@')
filter_list.extend([''] * (3 - len(filter_list)))
# TODO
assert optdim_list == [-1]
with file_path.open('a') as f:
kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
_, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
def list_blobs(file_path: Path, filter_list: str, receipt, optdim_list, mask_impl) -> None:
assert optdim_list == [-1] # TODO
_, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(
filter_list, receipt, mask_impl
)
with file_path.open("a") as f:
for k in kernels_dot_do_o:
f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n")
for k in kernels_dq_dk_dv:
f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n")
for k in kernels_convert_dq:
f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")