mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 05:19:20 +00:00
[CK_TILE] wa prec, remove sgpr offset for inline asm (#1356)
* wa prec, remove sgpr offset for inline asm
* macro for set tile
* ignore unused param if no kernel instances in host API
* fix more prec issue
* cache buffer resource
* fix
* support pre-nop
* clear tile by vector type members
* add workaround to reduce scratch memory
* conditionally enable workaround code
* enable workaround start from certain build version
* fallback set_tile() implementation from certain build version
* undo template argument changes
* put dummy asm in load_raw()
* fix comments, refactor s_nop inside buffer_load
---------
Co-authored-by: PoYen, Chen <PoYen.Chen@amd.com>
[ROCm/composable_kernel commit: 8182976c37]
This commit is contained in:
@@ -271,7 +271,9 @@ class FmhaBwdApiPool:
|
||||
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
# GEMM0: Q@K=S^T
|
||||
|
||||
@@ -278,6 +278,9 @@ class FmhaFwdApiPool:
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -331,6 +331,9 @@ class FmhaFwdSplitKVApiPool:
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user