fmha hdim256 vectorize improve (#2086)

For hdim 256, will not have vectorized buffer load when seqlen % 256 != 0 and hdim % 256 = 0; this commit tries to solve this condition.

[ROCm/composable_kernel commit: 94d47b1680]
This commit is contained in:
joyeamd
2025-04-16 09:21:04 +08:00
committed by GitHub
parent d9c9f17c3d
commit 19a9980cd5

View File

@@ -445,6 +445,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
# the below two is used for hdim vectorize load
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))