correct a typo in tail

This commit is contained in:
Lin, Qun
2025-05-25 02:13:15 -05:00
parent 8afac88f89
commit d5e7580473

View File

@@ -450,7 +450,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
// Initialize C
c_thread_buf.Clear();
__builtin_amdgcn_sched_barrier(0);
constexpr index_t SwitchM = MRepeat/ MXdlPack > LocalPrefetchStages ? MRepeat/ MXdlPack - LocalPrefetchStages : 0;
// main body
if constexpr(HasMainLoop)
{
@@ -516,7 +516,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
if constexpr(m0.value == (MRepeat/ MXdlPack - LocalPrefetchStages))
if constexpr(m0.value == SwitchM)
{
block_sync_lds();
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf));
@@ -524,7 +524,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
}
constexpr auto lds_buf =
m0.value >= (MRepeat/ MXdlPack - LocalPrefetchStages)
m0.value >= SwitchM
? scale_mem_buf
: scale_comp_buf;
@@ -778,14 +778,14 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
});
});
});
if constexpr(m0.value == (MRepeat - LocalPrefetchStages * MXdlPack) / MXdlPack)
if constexpr(m0.value == SwitchM)
{
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
}
constexpr auto lds_buf =
m0.value >= (MRepeat - LocalPrefetchStages * MXdlPack) / MXdlPack ? I1 : I0;
m0.value >= SwitchM ? I1 : I0;
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
@@ -893,7 +893,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
});
});
});
if constexpr(m0.value == (MRepeat - LocalPrefetchStages * MXdlPack) / MXdlPack)
if constexpr(m0.value < (MRepeat - LocalPrefetchStages * MXdlPack) / MXdlPack)
{
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
@@ -1058,7 +1058,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
});
});
});
if constexpr(m0.value == (MRepeat - LocalPrefetchStages * MXdlPack) / MXdlPack)
if constexpr(m0.value < (MRepeat - LocalPrefetchStages * MXdlPack) / MXdlPack)
{
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {