Enable hot-loop schedulers in the F16xMXF4 flatmm pipeline

This commit is contained in:
root
2025-08-21 09:46:52 -05:00
parent c378e9bdf8
commit 65989e940c
4 changed files with 39 additions and 19 deletions

View File

@@ -118,7 +118,9 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
static constexpr int MXFP4PackedSize = 2;
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType);
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * MXFP4PackedSize;
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
? DsReadPreload
: MIterPerWarp * KIterPerWarp;
@@ -127,8 +129,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
static constexpr int ContinuousScaleNPerThread = Problem::ContinuousScaleNPerThread;
static constexpr int ContinuousScaleKPerThread = Problem::ContinuousScaleKPerThread;
static constexpr int MXFP4PackedSize = 2;
static constexpr int ScaleKFlatPerWarp =
ContinuousScaleNPerThread * ContinuousScaleKPerThread * get_warp_size();
@@ -165,9 +165,14 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
static constexpr index_t Aload_num_perK = dswrite_num_perK;
static constexpr index_t Aload_rep = dswrite_rep;
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / K1 / WaveSize;
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize;
static constexpr index_t ScaleBload_K1 = ContinuousScaleNPerThread * ContinuousScaleKPerThread;
static constexpr index_t ScaleBload_num =
kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 /
WaveSize; // BlockN * BlockK / NWarp / ScalePerK / ScaleB_K1 / wavesize
static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num;
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg;
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
@@ -389,6 +394,10 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
? Aload_rep
: 0;
}
if((kIter % KPerScaleLoad == 0) && (mIter == 0))
{
load_perM = load_perM + 1;
}
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
}
}
@@ -875,7 +884,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
// HotLoopScheduler();
HotLoopScheduler();
// Next K
@@ -978,7 +987,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
// HotLoopScheduler();
HotLoopScheduler();
iCounter--;
}
@@ -1079,7 +1088,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
// Last2ndHotLoopScheduler();
Last2ndHotLoopScheduler();
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
@@ -1125,7 +1134,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
}
});
});
// LastHotLoopScheduler();
LastHotLoopScheduler();
}
else if constexpr(TailNum == TailNumber::Odd)
{
@@ -1174,7 +1183,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
}
});
});
// LastHotLoopScheduler();
LastHotLoopScheduler();
}
return c_block_tile;