mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] FMHA Reduce register spilling in fwd with dropout (workaround for CI failures with clang-22) (#3221)
* Use vectorized stores for dropout randvals With no kPadSeqLenK the kernel uses 2 buffer_store_dwordx2 instead of 16 buffer_store_byte. This requires less registers and reduces spilling. * Calculate dropout randvals for storing and applying only once Even though it may add a small overhead when storing is not required, it uses significantly less registers and hence no spilling.
This commit is contained in:
@@ -1005,7 +1005,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
rand_val_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.seqlen_k),
|
||||
make_tuple(kargs.stride_randval, 1),
|
||||
number<1>{},
|
||||
number<FmhaPipeline::kAlignmentRandVal>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(randval_dram_naive,
|
||||
|
||||
@@ -1450,7 +1450,7 @@ struct FmhaFwdKernel
|
||||
rand_val_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.seqlen_k),
|
||||
make_tuple(kargs.stride_randval, 1),
|
||||
number<1>{},
|
||||
number<FmhaPipeline::kAlignmentRandVal>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(randval_dram_naive,
|
||||
|
||||
Reference in New Issue
Block a user