Fix flash attn mask bug (#733)

* add check input parameter

* add instance for vector load = 1

* move gerneral instance to first pos

* fix read bias code

* regular code for bias load

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>

[ROCm/composable_kernel commit: 0ede66de54]
This commit is contained in:
ltqin
2023-06-12 21:35:31 +08:00
committed by GitHub
parent 9499f4b51b
commit 8c5f5f1293
7 changed files with 86 additions and 79 deletions

View File

@@ -121,7 +121,8 @@ using DeviceOpInstance =
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
MaskingSpec, // MaskingSpecialization
1>;
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,