mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
Fix flash attn mask bug (#733)
* add input parameter check
* add instance for vector load = 1
* move general instance to first position
* fix read bias code
* regular code for bias load
---------
Co-authored-by: zjing14 <zhangjing14@gmail.com>
[ROCm/composable_kernel commit: 0ede66de54]
This commit is contained in:
@@ -121,7 +121,8 @@ using DeviceOpInstance =
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
MaskingSpec>; // MaskingSpecialization
|
||||
MaskingSpec, // MaskingSpecialization
|
||||
1>;
|
||||
|
||||
// Ref Gemm0: fp16 in, fp32 out
|
||||
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
|
||||
Reference in New Issue
Block a user