mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 05:55:39 +00:00
[fix] align v3 gufusion pipeline
This commit is contained in:
@@ -143,7 +143,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::c_thread_desc_;
|
||||
|
||||
using Base::MWaves;
|
||||
|
||||
@@ -532,13 +531,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
|
||||
|
||||
// Local prefetch A1
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(I0, I0, I0, k0, I0, I0),
|
||||
a_block_buf.At(I0),
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
static_for<0, 2, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
|
||||
a_block_buf.At(I0),
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Initialize C
|
||||
@@ -858,7 +861,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(m0.value != (MRepeat - 1))
|
||||
if constexpr(m0.value < (MRepeat - 2))
|
||||
{
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
@@ -951,7 +954,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, 1, 1, KPack>,
|
||||
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
A_K1,
|
||||
@@ -963,6 +966,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
|
||||
|
||||
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
|
||||
|
||||
using Base::c_thread_desc_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user