[fix] align v3 gufusion pipeline

This commit is contained in:
lalala-sh
2025-04-30 02:27:39 +00:00
parent b8427b812e
commit 9c06c3817a

View File

@@ -143,7 +143,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::c_thread_desc_;
using Base::MWaves;
@@ -532,13 +531,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
// Local prefetch A1
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(I0, I0, I0, k0, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(I0, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, 2, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});
// Initialize C
@@ -858,7 +861,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
});
});
if constexpr(m0.value != (MRepeat - 1))
if constexpr(m0.value < (MRepeat - 2))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
@@ -951,7 +954,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
decltype(a_thread_desc_),
Sequence<1, 1, 1, 1, 1, KPack>,
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
@@ -963,6 +966,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
using Base::c_thread_desc_;
};
} // namespace ck