replace hard-coded WaveSize with call to get_warp_size()

This commit is contained in:
Philip Maybank
2025-07-29 11:55:27 +01:00
parent 1ca839925a
commit 7dbe5e7d37

View File

@@ -58,7 +58,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 32;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNumM = BlockGemm::MWarp;
constexpr index_t WaveNumN = BlockGemm::NWarp;