Adjust in GetNumPrefetchV()

This commit is contained in:
Qianfeng Zhang
2025-12-15 15:02:15 +00:00
parent 370d386427
commit d281c519f3

View File

@@ -26,12 +26,14 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumPrefetchV()
{
constexpr index_t n0_loops = Problem::BlockFmhaShape::kN0 / Problem::BlockFmhaShape::kN0Sub;
constexpr index_t k1_loops = Problem::BlockFmhaShape::kN0 / Problem::BlockFmhaShape::kK1;
// usually kN0 is 128, kK1 is 32/16
static_assert(k1_loops >= 2, "Check failed!");
return 2;
if constexpr(n0_loops >= 4 && k1_loops >= 6)
return 3;
if constexpr(k1_loops >= 4)
return 2;
return 1;
};
template <typename Problem>