Apply filter to every kernel in the codgen of FMHA (#1911)

* add receipt for fwd

* Add receipt for bwd

* Use kernel name to avoid more receipt

* apply filter to every kernel
This commit is contained in:
rocking
2025-02-26 20:20:29 +08:00
committed by GitHub
parent c9bcfd755e
commit e9ee568681
5 changed files with 126 additions and 66 deletions

View File

@@ -412,13 +412,19 @@ class FmhaBwdDQDKDVKernel:
pn = pad_name()
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}'
if pn != '' : n += f'_{pn}'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
if self.F_bias != 'no' :
n += f'_{self.F_bias}'
else:
n += '_nbias'
if self.F_dbias == 't' : n += '_dbias'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_dropout != 'no' : n += f'_{self.F_dropout}'
if self.F_dropout != 'no' :
n += f'_{self.F_dropout}'
else:
n += '_ndropout'
if self.F_deterministic == 't' : n += '_deterministic'
return n
@@ -489,7 +495,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic)
if kernel_filter != None:
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
# Flash attention integration
@@ -517,23 +523,19 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if not cond:
continue
# Aiter (mha_bwd) integration
elif receipt == 10:
elif receipt == 300:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"
cond &= bias in ['no', 'alibi']
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
cond &= deterministic == "t"
if not cond:
continue
# Aiter (mha_varlen_bwd) integration
elif receipt == 11:
elif receipt == 400:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
cond &= bias in ['no', 'alibi']
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
cond &= deterministic == "t"
if not cond:
continue
api_pool.register_dq_dk_dv_traits(k.api_trait())
@@ -638,7 +640,7 @@ class FmhaBwdOGradDotOKernel:
def filename(self) -> str:
return self.name + ".cpp"
def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdOGradDotOKernel]:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
# support this in future
def get_occupancy(dtype, hdim):
@@ -657,6 +659,21 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
k = FmhaBwdOGradDotOKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype,
F_spad=spad, F_dvpad=dvpad, F_mode=mode,
F_occupancy=get_occupancy(dtype, hdim))
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
# Aiter (mha_bwd) integration
if receipt == 300:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"
if not cond:
continue
# Aiter (mha_varlen_bwd) integration
elif receipt == 400:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
if not cond:
continue
gen.append(k)
return gen
@@ -773,7 +790,7 @@ class FmhaBwdConvertQGradKernel:
def filename(self) -> str:
return self.name + ".cpp"
def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdConvertQGradKernel]:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
# support this in future
def get_occupancy(dtype, hdim):
@@ -792,6 +809,21 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
continue
k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0,
F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic)
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
# Aiter (mha_bwd) integration
if receipt == 300:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"
if not cond:
continue
# Aiter (mha_varlen_bwd) integration
elif receipt == 400:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
if not cond:
continue
gen.append(k)
return gen
@@ -808,27 +840,33 @@ def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autoge
def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
kernels = get_bwd_dot_do_o_blobs()
def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None:
filter_list = filter_list.split('@')
filter_list.extend([''] * (3 - len(filter_list)))
kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt)
for kernel in kernels:
write_single_bwd_dot_do_o_kernel(kernel, output_dir)
kernels = get_bwd_convert_dq_blobs()
kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt)
for kernel in kernels:
write_single_bwd_convert_dq_kernel(kernel, output_dir)
api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
api_pool, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl)
for kernel in kernels:
write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
write_bwd_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None:
filter_list = filter_list.split('@')
filter_list.extend([''] * (3 - len(filter_list)))
with file_path.open('a') as f:
kernels = get_bwd_dot_do_o_blobs()
kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
kernels = get_bwd_convert_dq_blobs()
kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
_, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")

View File

@@ -233,13 +233,22 @@ class FmhaFwdPipeline:
pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
if self.F_bias != 'no' :
n += f'_{self.F_bias}'
else:
n += '_nbias'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_lse == 't' : n += '_lse'
if self.F_dropout == 't' : n += '_dropout'
if self.F_lse == 't' :
n += '_lse'
else:
n += '_nlse'
if self.F_dropout == 't' :
n += '_dropout'
else:
n += '_ndropout'
if self.F_squant == 't' : n += '_squant'
return n
@@ -484,7 +493,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != None:
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
# 2 - Flash attention integration
@@ -504,20 +513,18 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 10:
elif receipt == 100:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"
cond &= mode == 'batch'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 11:
elif receipt == 200:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
cond &= mode == 'group'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
if not cond:
continue
@@ -532,13 +539,13 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
def write_blobs(output_dir : Path, kernel_filter : str, receipt, mask_impl) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
def list_blobs(file_path : Path, kernel_filter : str, receipt, mask_impl) -> None:
with file_path.open('a') as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:

View File

@@ -323,12 +323,11 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != None:
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
# 2 - Flash attention integration
# 12 - Aiter(mha_fwd_kvcache) integration
if receipt in (2, 12):
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
if not cond:

View File

@@ -397,14 +397,23 @@ class FmhaFwdSplitKVPipeline:
pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
if self.F_bias != 'no' :
n += f'_{self.F_bias}'
else:
n += '_nbias'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_lse == 't' : n += '_lse'
if self.F_lse == 't' :
n += '_lse'
else:
n += '_nlse'
if self.F_squant == 't' : n += '_squant'
if self.F_pagedkv == 't' : n += '_pagedkv'
if self.F_pagedkv == 't' :
n += '_pagedkv'
else:
n += '_npagedkv'
return n
@dataclass
@@ -702,7 +711,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != None:
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
# Flash attention integration
@@ -714,20 +723,10 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 11:
elif receipt == 200:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# Aiter(mha_fwd_kvcache) integration
elif receipt == 12:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
if not cond:
continue
@@ -780,9 +779,15 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline)
if kernel_filter != None:
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
# Aiter(mha_varlen_fwd) integration
if receipt == 200:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
if not cond:
continue
gen.append(k)
return gen
@@ -794,21 +799,27 @@ def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) -
file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME
file_path.write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt)
def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None:
filter_list = filter_list.split('@')
filter_list.extend([''] * (2 - len(filter_list)))
kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
api_pool, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl)
api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
write_fwd_splitkv_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None:
filter_list = filter_list.split('@')
filter_list.extend([''] * (2 - len(filter_list)))
with file_path.open('a') as f:
kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt)
kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
_, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl)
_, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n")