mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
refactor fmha_bwd.py (#2546)
This commit is contained in:
@@ -7,7 +7,7 @@ from dataclasses import dataclass
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Dict, Literal
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
@@ -204,107 +204,13 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode})
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
class FmhaBwdDQDKDVApiTrait:
    """Dispatch-side description of one dq/dk/dv backward kernel configuration.

    The ``scheck``/``skcheck``/``dcheck``/``dvcheck`` helpers return condition
    text over the runtime argument object ``a``; the API pool inserts that text
    into the generated dispatch code. Padding flags are the strings ``'t'`` and
    ``'f'``.
    """
    pipeline      : str
    # sync with fmha_bwd_traits<>, to generate fallback calls
    hdim          : str
    dtype         : str  # data type
    mode          : str  # value from MODE_MAP
    bm0           : int  # tile size along q seqlen (block size)
    bn0           : int  # tile size along k seqlen
    bhdq          : int  # q head_dim
    bhdv          : int  # v head_dim
    mask          : str
    bias          : str
    dbias         : str
    dropout       : str
    spad          : str
    skpad         : str
    dpad          : str
    dvpad         : str
    deterministic : str

    def scheck(self, spad1 : str) -> str:
        """Return the seqlen_q condition for this trait combined with ``spad1``.

        Args:
            spad1: ``'t'`` or ``'f'``; the second seqlen_q padding flag that the
                API pool emits as ``F_spad1``. NOTE(review): presumably the
                padding flag of the dot_do_o/convert_dq kernels, which is
                checked against a fixed 64 here -- confirm.
        """
        if self.mode == 'group':
            return 'true'  # always support
        if spad1 == 't':
            if self.spad == 't':
                return f'a.seqlen_q % {self.bm0} != 0'
            if self.spad == 'f':
                return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0'
        # remaining case: self.spad == 'f' and spad1 == 'f'
        return 'a.seqlen_q % 64 == 0'

    @property
    def skcheck(self) -> str:
        """Condition on ``a.seqlen_k`` relative to the ``bn0`` tile size."""
        if self.mode == 'group':
            return 'true'  # always support
        relation = '!=' if self.skpad == 't' else '=='
        return f'a.seqlen_k % {self.bn0} {relation} 0'

    @property
    def dcheck(self) -> str:
        """Condition on ``a.hdim_q`` relative to the ``bhdq`` tile size."""
        relation = '!=' if self.dpad == 't' else '=='
        return f'a.hdim_q % {self.bhdq} {relation} 0'

    @property
    def dvcheck(self) -> str:
        """Condition on ``a.hdim_v`` relative to the ``bhdv`` tile size."""
        relation = '!=' if self.dvpad == 't' else '=='
        return f'a.hdim_v % {self.bhdv} {relation} 0'
|
||||
|
||||
class FmhaBwdApiPool:
    """Collects dq/dk/dv API traits and renders the generated dispatch source.

    Traits are stored as ``dq_dk_dv_pool[dtype][hdim] -> list of traits`` in
    registration order; that order determines the order of the emitted
    ``if`` / ``else if`` branches.
    """

    def __init__(self, mask_impl):
        self.dq_dk_dv_pool = {}
        self.mask_impl = mask_impl

    def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None:
        """Store a shallow copy of ``trait`` under its dtype and hdim."""
        # TODO: do we need to check duplication?
        per_hdim = self.dq_dk_dv_pool.setdefault(trait.dtype, {})
        per_hdim.setdefault(trait.hdim, []).append(copy.copy(trait))

    def _inner_dispatch(self, if_k : str, trait, spad1 : str, hdim, dtype) -> str:
        """Render one inner dispatch entry for ``trait`` with the given ``spad1``."""
        return FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
                F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
                F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
                F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
                F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                F_deterministic=BOOL_MAP[trait.deterministic])

    @property
    def api(self) -> str:
        """Return the full text of the generated API source.

        The previously computed but never used ``int(hdim)`` conversion has
        been dropped; the emitted text is unchanged.
        """
        dtype_cases = []
        for i, (dtype, per_hdim) in enumerate(self.dq_dk_dv_pool.items()):
            hdim_cases = []
            for j, (hdim, traits) in enumerate(per_hdim.items()):
                inners = []
                for k, trait in enumerate(traits):
                    if_k = 'if' if k == 0 else 'else if'
                    for spad1 in ("t", "f"):
                        # spad1 == "f" is only emitted for batch-mode traits with spad == "f"
                        if spad1 == "f" and (trait.spad == "t" or trait.mode == "group"):
                            continue
                        inners.append(self._inner_dispatch(if_k, trait, spad1, hdim, dtype))

                if_j = 'if' if j == 0 else 'else if'
                hdim_cases.append(FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=''.join(inners)))
            if_i = 'if' if i == 0 else 'else if'
            dtype_cases.append(FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=''.join(hdim_cases)))
        per_dtypes = ''.join(dtype_cases)
        if not per_dtypes:
            # empty string we add some ignore to suppress warning in api
            per_dtypes += ' (void)t ; (void)s ; (void)a;'
        return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
# GEMM0: Q@K=S^T
|
||||
# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
|
||||
# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order)
|
||||
# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk)
|
||||
# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk)
|
||||
# Is it necessary to distinguish between K0~K4?
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdDQDKDVTileSize:
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
@@ -337,7 +243,7 @@ class FmhaBwdDQDKDVTileSize:
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}"
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdDQDKDVKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
@@ -440,26 +346,6 @@ class FmhaBwdDQDKDVKernel:
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaBwdDQDKDVApiTrait:
    """Build the API trait carrying this kernel's configuration.

    ``hdim`` is passed as a string; the tile sizes are read from ``self.F_tile``.
    """
    tile = self.F_tile
    return FmhaBwdDQDKDVApiTrait(
        pipeline=self.F_pipeline,
        hdim=str(self.F_hdim),
        dtype=self.F_dtype,
        mode=self.F_mode,
        bm0=tile.F_bm0,
        bn0=tile.F_bn0,
        bhdq=tile.F_bhdq,
        bhdv=tile.F_bhdv,
        mask=self.F_mask,
        bias=self.F_bias,
        dbias=self.F_dbias,
        dropout=self.F_dropout,
        spad=self.F_spad,
        skpad=self.F_skpad,
        dpad=self.F_dpad,
        dvpad=self.F_dvpad,
        deterministic=self.F_deterministic,
    )
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size & pipeline.
|
||||
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
@@ -477,84 +363,6 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict
|
||||
else:
|
||||
return None
|
||||
|
||||
def _bwd_dq_dk_dv_receipt_accepts(receipt, dtype, mode, bias, dropout, dpad, dvpad, deterministic) -> bool:
    """Return True when ``receipt`` keeps this dq/dk/dv kernel configuration."""
    half_dtype = dtype in ['fp16', 'bf16']
    listed_dropout = dropout in ['no', 'dropout_wg32', 'dropout_wg16']
    # Flash attention integration
    if receipt == 2:
        return half_dtype and bias in ['no', 'alibi'] and listed_dropout and dpad == dvpad
    if receipt == 3:
        return half_dtype and bias in ['no', 'alibi'] and dpad == dvpad and deterministic == "f"
    # PyTorch integration
    if receipt == 4:
        return (half_dtype and bias in ['no', 'bias'] and listed_dropout and dpad == dvpad
                and mode == 'batch' and deterministic == "f")
    # Aiter (mha_bwd) integration
    if receipt == 300:
        return half_dtype and mode == "batch" and listed_dropout
    # Aiter (mha_varlen_bwd) integration
    if receipt == 400:
        return half_dtype and mode == "group" and listed_dropout
    # aiter::mha_bwd C++ api integration
    if receipt == 600:
        return half_dtype
    # any other receipt value applies no extra restriction
    return True


def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]:
    """Enumerate the dq/dk/dv backward kernels and register their API traits.

    Args:
        kernel_filter: fnmatch pattern applied to each kernel name; an empty
            string or ``None`` disables name filtering.
        receipt: integration selector; see ``_bwd_dq_dk_dv_receipt_accepts``.
        mask_impl: mask implementation key passed to ``get_mask_map`` and to
            the kernels.

    Returns:
        ``(api_pool, kernels)`` where every kernel in ``kernels`` has had its
        ``api_trait()`` registered in ``api_pool``.
    """
    # TODO: we don't support tuning yet, so pick up one value for pad
    #       support this in future
    gen = []
    api_pool = FmhaBwdApiPool(mask_impl)
    flags = ["t", "f"]

    for dtype in BWD_DTYPE_MAP.keys():
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
        if d is None:
            continue
        for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad, deterministic in itertools.product(
                d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(),
                flags, DROPOUT_MAP.keys(), flags, flags, flags, flags, flags):
            tile = d[hdim_str][0]
            ppl = d[hdim_str][1]
            hdim = int(hdim_str)
            # group mode is only generated with both sequence paddings enabled
            if (mode == "group") and (spad == "f" or skpad == "f"):
                continue
            # dbias is only generated together with an elementwise bias
            if ((bias == "no" or bias == "alibi") and dbias == "t"):
                continue
            if ("wg32" in dropout):
                continue
            # head-dim padding selects the third entry of the per-hdim tuple as pipeline
            if (dpad == "t" or dvpad == "t"):
                ppl = d[hdim_str][2]
            k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
                            F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
                            F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
                            F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic)
            # truthiness test so that None (allowed by the annotation) means "no filter"
            if kernel_filter and not fnmatch.fnmatch(k.name, kernel_filter):
                continue
            if not _bwd_dq_dk_dv_receipt_accepts(receipt, dtype, mode, bias, dropout, dpad, dvpad, deterministic):
                continue
            api_pool.register_dq_dk_dv_traits(k.api_trait())
            gen.append(k)

    return (api_pool, gen)
|
||||
|
||||
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
@@ -613,7 +421,7 @@ std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdOGradDotOKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
@@ -653,49 +461,6 @@ class FmhaBwdOGradDotOKernel:
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def _bwd_dot_do_o_receipt_accepts(receipt, dtype, mode) -> bool:
    """Return True when ``receipt`` keeps this dot_do_o kernel configuration."""
    half_dtype = dtype in ['fp16', 'bf16']
    # Aiter (mha_bwd) integration
    if receipt == 300:
        return half_dtype and mode == "batch"
    # Aiter (mha_varlen_bwd) integration
    if receipt == 400:
        return half_dtype and mode == "group"
    # aiter::mha_bwd C++ api integration
    if receipt == 600:
        return half_dtype
    # any other receipt value applies no extra restriction
    return True


def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdOGradDotOKernel]:
    """Enumerate the dot_do_o backward kernels to generate.

    Args:
        kernel_filter: fnmatch pattern applied to each kernel name; an empty
            string or ``None`` disables name filtering.
        receipt: integration selector; see ``_bwd_dot_do_o_receipt_accepts``.

    Returns:
        The accepted kernels in enumeration order.
    """
    # TODO: we don't support tuning yet, so pick up one value for pad/occupancy
    #       support this in future
    def get_occupancy(dtype, hdim):
        return 2

    gen = []
    flags = ["t", "f"]

    for dtype in BWD_DTYPE_MAP.keys():
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
        if d is None:
            continue
        for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), flags, flags):
            hdim = int(hdim_str)
            # group mode is only generated with seqlen_q padding enabled
            if (mode == "group" and spad == "f"):
                continue
            k = FmhaBwdOGradDotOKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype,
                            F_spad=spad, F_dvpad=dvpad, F_mode=mode,
                            F_occupancy=get_occupancy(dtype, hdim))
            # truthiness test so that None (allowed by the annotation) means "no filter"
            if kernel_filter and not fnmatch.fnmatch(k.name, kernel_filter):
                continue
            if not _bwd_dot_do_o_receipt_accepts(receipt, dtype, mode):
                continue
            gen.append(k)

    return gen
|
||||
|
||||
FMHA_BWD_CONVERT_DQ_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
@@ -762,7 +527,7 @@ std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}>()
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdConvertQGradKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
@@ -810,92 +575,257 @@ class FmhaBwdConvertQGradKernel:
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdConvertQGradKernel]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
|
||||
# support this in future
|
||||
def get_occupancy(dtype, hdim):
|
||||
return 2
|
||||
@dataclass(frozen=True)
class FmhaBwdApiTrait:
    """One backward API configuration, from which all three kernels are derived.

    The instance is frozen. The ``*check`` helpers return condition text over
    the runtime argument object ``a``; the ``*_kernel`` properties construct
    the kernel descriptions that belong to this configuration.
    """
    idx           : int  # this is not a tunable, but a counter to differentiate symbol
    pipeline      : str
    # sync with fmha_bwd_traits<>, to generate fallback calls
    hdim          : int
    dtype         : str  # data type
    mode          : str  # value from MODE_MAP
    tile          : FmhaBwdDQDKDVTileSize
    mask          : str
    bias          : str
    dbias         : str
    dropout       : str
    spad          : str
    spad1         : str  # spad for dot/convert kernel
    skpad         : str
    dpad          : str
    dvpad         : str
    deterministic : str
    mask_impl     : str

    @property
    def bm0(self) -> int:
        """Tile field ``F_bm0``."""
        return self.tile.F_bm0

    @property
    def bn0(self) -> int:
        """Tile field ``F_bn0``."""
        return self.tile.F_bn0

    @property
    def bhdq(self) -> int:
        """Tile field ``F_bhdq``."""
        return self.tile.F_bhdq

    @property
    def bhdv(self) -> int:
        """Tile field ``F_bhdv``."""
        return self.tile.F_bhdv

    def scheck(self, spad1 : str) -> str:
        """Return the seqlen_q condition for ``self.spad`` combined with ``spad1``.

        Only the ``spad1`` argument is consulted here, not the ``spad1`` field.
        """
        if self.mode == 'group':
            return 'true'  # always support
        if spad1 == 't':
            if self.spad == 't':
                return f'a.seqlen_q % {self.bm0} != 0'
            if self.spad == 'f':
                return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0'
        # remaining case: self.spad == 'f' and spad1 == 'f'
        return 'a.seqlen_q % 64 == 0'

    @property
    def skcheck(self) -> str:
        """Condition on ``a.seqlen_k`` relative to ``bn0``."""
        if self.mode == 'group':
            return 'true'  # always support
        if self.skpad == 't':
            return f'a.seqlen_k % {self.bn0} != 0'
        return f'a.seqlen_k % {self.bn0} == 0'

    @property
    def dcheck(self) -> str:
        """Condition on ``a.hdim_q`` relative to ``bhdq``."""
        if self.dpad == 't':
            return f'a.hdim_q % {self.bhdq} != 0'
        return f'a.hdim_q % {self.bhdq} == 0'

    @property
    def dvcheck(self) -> str:
        """Condition on ``a.hdim_v`` relative to ``bhdv``."""
        if self.dvpad == 't':
            return f'a.hdim_v % {self.bhdv} != 0'
        return f'a.hdim_v % {self.bhdv} == 0'

    @property
    def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel:
        """Kernel description for the dot_do_o step; its seqlen padding is ``spad1``."""
        # TODO: we don't support tuning yet, so pick up one value for pad/occupancy
        #       support this in future
        occupancy = 2
        return FmhaBwdOGradDotOKernel(
            F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1,
            F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=occupancy)

    @property
    def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel:
        """Kernel description for the main dq/dk/dv step."""
        return FmhaBwdDQDKDVKernel(
            F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile,
            F_spad=self.spad, F_skpad=self.skpad, F_dpad=self.dpad, F_dvpad=self.dvpad,
            F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout, F_mask=self.mask,
            F_mode=self.mode, F_deterministic=self.deterministic, F_pipeline=self.pipeline,
            mask_impl=self.mask_impl)

    @property
    def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel:
        """Kernel description for the convert_dq step (``F_bm0`` fixed to 64)."""
        # TODO: we don't support tuning yet, so pick up one value for pad/occupancy
        #       support this in future
        occupancy = 2
        # NOTE(review): the field comment says spad1 is the padding flag for the
        # dot/convert kernels, but F_spad below is taken from self.spad -- confirm
        # which one is intended.
        return FmhaBwdConvertQGradKernel(
            F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
            F_bm0=64, F_bn0=self.tile.F_bn0, F_spad=self.spad, F_dpad=self.dpad,
            F_mode=self.mode, F_occupancy=occupancy,
            F_deterministic=self.deterministic)
|
||||
|
||||
class FmhaBwdApiPool:
    """Groups registered API traits by dtype and hdim and renders the API source.

    Layout: ``dq_dk_dv_pool[dtype][hdim] -> list of traits`` in registration
    order, which is also the order of the emitted ``if`` / ``else if`` branches.
    """

    def __init__(self, mask_impl):
        self.dq_dk_dv_pool = {}
        self.mask_impl = mask_impl

    def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None:
        """Append a shallow copy of ``trait`` to its (dtype, hdim) bucket."""
        # TODO: do we need to check duplication?
        per_hdim = self.dq_dk_dv_pool.setdefault(trait.dtype, {})
        per_hdim.setdefault(trait.hdim, []).append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Return the full text of the generated API source."""
        dtype_cases = []
        for i, (dtype, per_hdim) in enumerate(self.dq_dk_dv_pool.items()):
            hdim_cases = []
            for j, (hdim, traits) in enumerate(per_hdim.items()):
                dispatches = []
                for k, trait in enumerate(traits):
                    if_k = 'else if' if k else 'if'
                    # NOTE(review): spad1 is enumerated here independently of
                    # trait.spad1; if traits are registered once per spad1 value
                    # this emits repeated entries -- confirm whether intended.
                    for spad1 in ("t", "f"):
                        if spad1 == "f" and (trait.spad == "t" or trait.mode == "group"):
                            continue
                        dispatches.append(FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
                                F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
                                F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
                                F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
                                F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                                F_deterministic=BOOL_MAP[trait.deterministic]))

                if_j = 'else if' if j else 'if'
                hdim_cases.append(FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=''.join(dispatches)))
            if_i = 'else if' if i else 'if'
            dtype_cases.append(FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=''.join(hdim_cases)))
        per_dtypes = ''.join(dtype_cases)
        if not per_dtypes:
            # empty string we add some ignore to suppress warning in api
            per_dtypes += ' (void)t ; (void)s ; (void)a;'
        return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
def get_bwd_blobs(filter_list: str, receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]:
|
||||
if filter_list == '':
|
||||
filter_list = '*@*@*'
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend(['*'] * (3 - len(filter_list)))
|
||||
filter_dot_do_o = filter_list[0]
|
||||
filter_convert_dq = filter_list[1]
|
||||
filter_dq_dk_dv = filter_list[2]
|
||||
|
||||
# use dict as ordered set
|
||||
gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {}
|
||||
gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = {}
|
||||
gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {}
|
||||
api_pool = FmhaBwdApiPool(mask_impl)
|
||||
|
||||
for dtype in BWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
if d is None:
|
||||
continue
|
||||
for hdim_str, mode, spad, dpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
hdim = int(hdim_str)
|
||||
for hdim_str, mode, mask, bias, dbias, dropout, spad, spad1, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 6)):
|
||||
tile = d[hdim_str][0]
|
||||
if (mode == "group" and spad == "f"):
|
||||
ppl = d[hdim_str][1]
|
||||
hdim = int(hdim_str)
|
||||
if (mode == "group") and (spad == "f" or skpad == "f"):
|
||||
continue
|
||||
k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0,
|
||||
F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic)
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
if (spad1 == "f") and (spad == "t" or mode == "group"):
|
||||
continue
|
||||
if ((bias == "no" or bias == "alibi") and dbias == "t"):
|
||||
continue
|
||||
if ("wg32" in dropout):
|
||||
continue
|
||||
if (dpad == "t" or dvpad == "t"):
|
||||
ppl = d[hdim_str][2]
|
||||
t = FmhaBwdApiTrait(idx=0, pipeline=ppl, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad=spad, spad1=spad1, skpad=skpad, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl)
|
||||
|
||||
if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
|
||||
continue
|
||||
if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv):
|
||||
continue
|
||||
if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq):
|
||||
continue
|
||||
|
||||
# Flash attention integration
|
||||
if receipt == 2:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 3:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dpad == dvpad
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'bias']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
cond &= mode == 'batch'
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter (mha_bwd) integration
|
||||
if receipt == 300:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "batch"
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 300:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "batch"
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter (mha_varlen_bwd) integration
|
||||
elif receipt == 400:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
if not cond:
|
||||
continue
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_bwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
if not cond:
|
||||
continue
|
||||
gen.append(k)
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
if not cond:
|
||||
continue
|
||||
gen_dot_do_o[t.dot_do_o_kernel] = True
|
||||
gen_dq_dk_dv[t.dq_dk_dv_kernel] = True
|
||||
gen_convert_dq[t.convert_dq_kernel] = True
|
||||
api_pool.register_dq_dk_dv_traits(t)
|
||||
|
||||
return gen
|
||||
|
||||
def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None:
    """Write ``kernel.template`` to the file ``kernel.filename`` inside ``autogen_dir``."""
    target = autogen_dir / kernel.filename
    target.write_text(kernel.template)
|
||||
|
||||
def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None:
    """Write ``kernel.template`` to the file ``kernel.filename`` inside ``autogen_dir``."""
    target = autogen_dir / kernel.filename
    target.write_text(kernel.template)
|
||||
|
||||
def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path) -> None:
    """Write ``kernel.template`` to the file ``kernel.filename`` inside ``autogen_dir``."""
    target = autogen_dir / kernel.filename
    target.write_text(kernel.template)
|
||||
|
||||
def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
    """Write the rendered API source to ``FMHA_BWD_API_FILENAME`` inside ``autogen_dir``."""
    target = autogen_dir / FMHA_BWD_API_FILENAME
    target.write_text(api_pool.api)
|
||||
return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys())
|
||||
|
||||
def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend([''] * (3 - len(filter_list)))
|
||||
# TODO
|
||||
assert optdim_list == [-1]
|
||||
assert optdim_list == [-1] # TODO
|
||||
|
||||
kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt)
|
||||
for kernel in kernels:
|
||||
write_single_bwd_dot_do_o_kernel(kernel, output_dir)
|
||||
kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt)
|
||||
for kernel in kernels:
|
||||
write_single_bwd_convert_dq_kernel(kernel, output_dir)
|
||||
api_pool, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
|
||||
write_bwd_api(api_pool, output_dir)
|
||||
api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl)
|
||||
(output_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
|
||||
for k in kernels_dot_do_o:
|
||||
(output_dir / k.filename).write_text(k.template)
|
||||
for k in kernels_convert_dq:
|
||||
(output_dir / k.filename).write_text(k.template)
|
||||
for k in kernels_dq_dk_dv:
|
||||
(output_dir / k.filename).write_text(k.template)
|
||||
|
||||
def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
    """Append the path of every kernel source that would be generated to ``file_path``.

    ``filter_list`` holds up to three ``'@'``-separated name patterns, in the
    order dot_do_o, convert_dq, dq_dk_dv; missing entries become ``''``.
    One path is written per line, each under ``file_path.parent / GEN_DIR``.
    """
    patterns = filter_list.split('@')
    patterns += [''] * (3 - len(patterns))
    # TODO
    assert optdim_list == [-1]

    with file_path.open('a') as f:
        def emit(kernels) -> None:
            # one line per kernel source file
            for kernel in kernels:
                f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")

        emit(get_bwd_dot_do_o_blobs(patterns[0], receipt))
        emit(get_bwd_convert_dq_blobs(patterns[1], receipt))
        _, dq_dk_dv_kernels = get_bwd_dq_dk_dv_blobs(patterns[2], receipt, mask_impl)
        emit(dq_dk_dv_kernels)
|
||||
def list_blobs(file_path: Path, filter_list: str, receipt, optdim_list, mask_impl) -> None:
    """Append the paths of all generated backward sources to ``file_path``.

    Kernel files are listed in the order dot_do_o, dq_dk_dv, convert_dq,
    followed by the API file. One path is written per line, each under
    ``file_path.parent / GEN_DIR``.
    """
    assert optdim_list == [-1]  # TODO

    _, dot_do_o, dq_dk_dv, convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl)
    gen_root = file_path.parent / GEN_DIR
    with file_path.open("a") as f:
        for kernels in (dot_do_o, dq_dk_dv, convert_dq):
            for k in kernels:
                f.write(str(gen_root / k.filename) + "\n")
        f.write(str(gen_root / FMHA_BWD_API_FILENAME) + "\n")
|
||||
|
||||
Reference in New Issue
Block a user