mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Fix flash attn mask bug (#733)
* add check input parameter * add instance for vector load = 1 * move general instance to first pos * fix read bias code * regular code for bias load --------- Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -197,7 +197,8 @@ template <index_t NumDimG,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
MaskingSpecialization MaskingSpec,
|
||||
LoopScheduler LoopSched = LoopScheduler::Default>
|
||||
int D0sTransferSrcScalarPerVector = 4,
|
||||
LoopScheduler LoopSched = LoopScheduler::Default>
|
||||
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
: public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
|
||||
NumDimM,
|
||||
@@ -438,7 +439,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
Transform::matrix_padder.PadN,
|
||||
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
|
||||
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
|
||||
D0sTransferSrcScalarPerVector>;
|
||||
|
||||
// Argument
|
||||
// FIXME: constness
|
||||
@@ -530,6 +532,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
|
||||
// D0 pointer
|
||||
p_d0s_grid_(i) = static_cast<const D0DataType*>(p_acc0_biases[i]);
|
||||
// for check
|
||||
d0s_nl_ns_lengths_strides_[i].push_back(
|
||||
acc0_biases_gs_ms_ns_lengths[i][NumDimG + NumDimM]);
|
||||
d0s_nl_ns_lengths_strides_[i].push_back(
|
||||
acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]);
|
||||
});
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
|
||||
@@ -608,6 +615,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
std::vector<index_t> b_nz_kz_strides_;
|
||||
std::vector<index_t> b1_nz_kz_strides_;
|
||||
std::vector<index_t> c_mz_gemm1nz_strides_;
|
||||
std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_nl_ns_lengths_strides_;
|
||||
|
||||
index_t batch_count_;
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
|
||||
@@ -772,6 +780,20 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
for(int i = 0; i < NumD0Tensor; i++)
|
||||
{
|
||||
if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 &&
|
||||
arg.d0s_nl_ns_lengths_strides_[i][0] % D0sTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
std::cout << "first" << std::endl;
|
||||
return false;
|
||||
}
|
||||
if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && D0sTransferSrcScalarPerVector != 1)
|
||||
{
|
||||
std::cout << "second" << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
|
||||
Reference in New Issue
Block a user