[CK TILE] Fix KIterPerInnerLoop for block gemm. (#1934)

* Fix KIterPerInnerLoop

* Fix Kpack and KPerInnerLoop for block universal gemm.

* Fix overlooked spelling bugs.
This commit is contained in:
Adam Osewski
2025-03-05 23:17:44 +01:00
committed by GitHub
parent 9b51c08bf7
commit 4814db3905
3 changed files with 23 additions and 19 deletions

View File

@@ -68,15 +68,19 @@ struct BlockUniversalGemmAsBsCr
static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN;
static constexpr index_t KPerBlockPerIter = WarpGemm::kK;
// TODO: Should we have two policies? Interwave & Intrawave ??
// Controls how many MAC clusters (MFMA blocks) we have per wave
// Ie if
// InterWaveSchedulingMacClusters = 1;
// KPerBlock == 32
// WarpGemm::kK = 8
// Then we would group all 4 WarpGemms into single MAC cluster.
// But if we would set InterWaveSchedulingMacClusters = 2, then we would
// split those 4 warp gemms into two groups.
static constexpr index_t InterWaveSchedulingMacClusters = 1;
// should be at least equal to: WarpGemm::Impl::kABKPerLane
// and the question is how to assess upper limit or exact value?
// TODO: Should we introduce AK1/BK1 parameters ?
static constexpr index_t KPack = 8;
static constexpr index_t KPerThread = KIterPerWarp * KPack;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
};
public:
@@ -129,11 +133,12 @@ struct BlockUniversalGemmAsBsCr
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack);
constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KIterInterWave>,
sequence<KIterInterwave>,
sequence<KIterPerWarp>>;
constexpr auto a_block_outer_dstr_encoding =
@@ -153,11 +158,12 @@ struct BlockUniversalGemmAsBsCr
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack);
constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KIterInterWave>,
sequence<KIterInterwave>,
sequence<KIterPerWarp>>;
constexpr auto b_block_outer_dstr_encoding =
@@ -371,11 +377,9 @@ struct BlockUniversalGemmAsBsCr
static constexpr index_t KPerThread = GemmTraits::KPerThread;
static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
static constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, GemmTraits::KPack);
// TODO: do we really need this?? Are there any cases when this would be >=1 ??
// Would we need InterWaveSchedulingMacClusters > 1 ???
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t KInnerLoopIter = KPerInnerLoop / GemmTraits::KPack;
static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
static constexpr auto ALdsTileDistr =
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};