Fix flash attn mask bug (#733)

* add check input parameter

* add instance for vector load = 1

* move general instance to first position

* fix read bias code

* regular code for bias load

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
ltqin
2023-06-12 21:35:31 +08:00
committed by GitHub
parent 016ebaa7f3
commit 0ede66de54
7 changed files with 86 additions and 79 deletions

View File

@@ -197,7 +197,8 @@ template <index_t NumDimG,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default>
int D0sTransferSrcScalarPerVector = 4,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
@@ -438,7 +439,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
D0sTransferSrcScalarPerVector>;
// Argument
// FIXME: constness
@@ -530,6 +532,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
// D0 pointer
p_d0s_grid_(i) = static_cast<const D0DataType*>(p_acc0_biases[i]);
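// Record this D0 tensor's length and stride at dimension index NumDimG + NumDimM
// (length first, then stride) so they can be validated against
// D0sTransferSrcScalarPerVector later.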
d0s_nl_ns_lengths_strides_[i].push_back(
acc0_biases_gs_ms_ns_lengths[i][NumDimG + NumDimM]);
d0s_nl_ns_lengths_strides_[i].push_back(
acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]);
});
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
@@ -608,6 +615,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::vector<index_t> b_nz_kz_strides_;
std::vector<index_t> b1_nz_kz_strides_;
std::vector<index_t> c_mz_gemm1nz_strides_;
std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_nl_ns_lengths_strides_;
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
@@ -772,6 +780,20 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
{
return false;
}
// Validate every D0 tensor's recorded {length, stride} pair against the
// compile-time D0sTransferSrcScalarPerVector. Element [0] is the length and
// element [1] is the stride; any mismatch rejects the argument.
for(int i = 0; i < NumD0Tensor; i++)
{
// Stride of 1: the length must be a multiple of D0sTransferSrcScalarPerVector.
if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 &&
arg.d0s_nl_ns_lengths_strides_[i][0] % D0sTransferSrcScalarPerVector != 0)
{
// NOTE(review): "first" looks like a leftover debug print — confirm whether
// it should stay or be replaced by a descriptive message.
std::cout << "first" << std::endl;
return false;
}
// Stride other than 1: only D0sTransferSrcScalarPerVector == 1 is accepted.
if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && D0sTransferSrcScalarPerVector != 1)
{
// NOTE(review): "second" looks like a leftover debug print — confirm as above.
std::cout << "second" << std::endl;
return false;
}
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,