mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
Only generate specific hdim (#2120)
This commit is contained in:
@@ -866,9 +866,11 @@ def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autoge
|
||||
def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None:
|
||||
def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend([''] * (3 - len(filter_list)))
|
||||
# TODO
|
||||
assert optdim_list == [-1]
|
||||
|
||||
kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt)
|
||||
for kernel in kernels:
|
||||
@@ -881,9 +883,11 @@ def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> Non
|
||||
write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
|
||||
write_bwd_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None:
|
||||
def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend([''] * (3 - len(filter_list)))
|
||||
# TODO
|
||||
assert optdim_list == [-1]
|
||||
|
||||
with file_path.open('a') as f:
|
||||
kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt)
|
||||
|
||||
@@ -429,7 +429,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
|
||||
@@ -507,6 +507,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
@@ -557,15 +560,15 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : str, receipt, mask_impl) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
|
||||
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : str, receipt, mask_impl) -> None:
|
||||
def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
with file_path.open('a') as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
|
||||
|
||||
@@ -343,13 +343,15 @@ def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> Non
|
||||
def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None:
|
||||
assert optdim_list == [-1]
|
||||
api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_kernel(kernel, output_dir)
|
||||
write_fwd_appendkv_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None:
|
||||
assert optdim_list == [-1]
|
||||
with file_path.open('a') as f:
|
||||
_, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
|
||||
@@ -440,10 +440,10 @@ class FmhaFwdSplitKVCombinePipeline:
|
||||
n = f'{self.tag}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
else: n += '_nlse'
|
||||
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
return n
|
||||
@@ -819,9 +819,10 @@ def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) -
|
||||
file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME
|
||||
file_path.write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None:
|
||||
def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend([''] * (2 - len(filter_list)))
|
||||
assert optdim_list == [-1]
|
||||
|
||||
kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt)
|
||||
for kernel in kernels:
|
||||
@@ -831,9 +832,10 @@ def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> Non
|
||||
write_single_kernel(kernel, output_dir)
|
||||
write_fwd_splitkv_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None:
|
||||
def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend([''] * (2 - len(filter_list)))
|
||||
assert optdim_list == [-1]
|
||||
|
||||
with file_path.open('a') as f:
|
||||
kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt)
|
||||
|
||||
Reference in New Issue
Block a user