mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Fix flash attn mask bug (#733)
* add check input parameter * add instance for vector load = 1 * move gerneral instance to first pos * fix read bias code * regular code for bias load --------- Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -80,7 +80,8 @@ template <typename FloatAB,
|
||||
LoopScheduler LoopSched,
|
||||
bool PadN,
|
||||
bool MaskOutUpperTriangle,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
int D0sTransferSrcScalarPerVector = 4,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
|
||||
{
|
||||
static_assert(LoopSched == LoopScheduler::Default,
|
||||
@@ -621,13 +622,13 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
|
||||
constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
|
||||
I1, // NBlockID
|
||||
I1, // MRepeat
|
||||
I1, // NRepeat
|
||||
I1, // MWaveId
|
||||
I1, // NWaveId
|
||||
I1, // MPerXdl
|
||||
I1, // NGroupNum
|
||||
I1, // NInputNum
|
||||
m0, // MRepeat
|
||||
n0, // NRepeat
|
||||
m1, // MWaveId
|
||||
n1, // NWaveId
|
||||
m2, // MPerXdl
|
||||
n2, // NGroupNum
|
||||
n3, // NInputNum
|
||||
n4)); // registerNum
|
||||
|
||||
auto d0s_thread_buf = generate_tuple(
|
||||
@@ -644,9 +645,6 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
|
||||
const auto wave_id = GetGemm0WaveIdx();
|
||||
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
|
||||
|
||||
constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MXdlPerWave>{}, Number<NXdlPerWave>{}, n2, n4));
|
||||
|
||||
auto d0s_threadwise_copy = generate_tuple(
|
||||
[&](auto i) {
|
||||
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
|
||||
@@ -655,10 +653,19 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
|
||||
D0DataType,
|
||||
decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]),
|
||||
decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
|
||||
Sequence<I1, I1, I1, I1, I1, I1, I1, I1, I1, n4>,
|
||||
Sequence<I1, // MBlockId
|
||||
I1, // NBlockID
|
||||
m0, // MRepeat
|
||||
n0, // NRepeat
|
||||
m1, // MWaveId
|
||||
n1, // NWaveId
|
||||
m2, // MPerXdl
|
||||
n2, // NGroupNum
|
||||
n3, // NInputNum
|
||||
n4>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
|
||||
9,
|
||||
n4,
|
||||
D0sTransferSrcScalarPerVector,
|
||||
1,
|
||||
false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
|
||||
make_multi_index(block_work_idx[I0], // MBlockId
|
||||
@@ -884,62 +891,35 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
|
||||
// multiple d
|
||||
if constexpr(NumD0Tensor)
|
||||
{
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto mr) {
|
||||
static_for<0, NXdlPerWave, 1>{}([&](auto nr) {
|
||||
static_for<0, n2, 1>{}([&](auto groupid) {
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
d0s_threadwise_copy(i).Run(
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
|
||||
d0s_grid_buf[i],
|
||||
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
d0s_thread_buf(i));
|
||||
});
|
||||
static_assert(NXdlPerWave == n0);
|
||||
static_assert(MXdlPerWave == m0);
|
||||
|
||||
static_for<0, n4, 1>{}([&](auto i) {
|
||||
constexpr index_t c_offset = acc0_thread_desc.CalculateOffset(
|
||||
make_tuple(mr, nr, groupid, i));
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
|
||||
d0s_grid_buf[i],
|
||||
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
d0s_thread_buf(i));
|
||||
});
|
||||
static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
|
||||
// get reference to src data
|
||||
const auto src_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; },
|
||||
Number<NumD0Tensor>{});
|
||||
|
||||
// get reference to src data
|
||||
const auto src_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iSrc) -> const auto& {
|
||||
return d0s_thread_buf[iSrc][i];
|
||||
},
|
||||
Number<NumD0Tensor>{});
|
||||
// get reference to dst data
|
||||
auto dst_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto) -> auto& { return acc_thread_buf(i); },
|
||||
Number<2>{});
|
||||
|
||||
// get reference to dst data
|
||||
auto dst_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto) -> auto& {
|
||||
return acc_thread_buf(Number<c_offset>{});
|
||||
},
|
||||
Number<2>{});
|
||||
|
||||
unpack2(c0de_element_op, dst_data_refs, src_data_refs);
|
||||
});
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
d0s_threadwise_copy(i).MoveSrcSliceWindow(
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
|
||||
make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0));
|
||||
});
|
||||
});
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
d0s_threadwise_copy(i).MoveSrcSliceWindow(
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
|
||||
make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0));
|
||||
});
|
||||
});
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
d0s_threadwise_copy(i).MoveSrcSliceWindow(
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
|
||||
make_multi_index(0, 0, 1, -NXdlPerWave, 0, 0, 0, 0, 0, 0));
|
||||
});
|
||||
unpack2(c0de_element_op, dst_data_refs, src_data_refs);
|
||||
});
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
d0s_threadwise_copy(i).MoveSrcSliceWindow(
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
|
||||
make_multi_index(0, 1, -MXdlPerWave, 0, 0, 0, 0, 0, 0, 0));
|
||||
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
|
||||
});
|
||||
}
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user