mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[CK TILE] Fix KIterPerInnerLoop for block gemm. (#1934)
* Fix KIterPerInnerLoop.
* Fix KPack and KPerInnerLoop for block universal gemm.
* Fix overlooked spelling bugs.
This commit is contained in:
@@ -68,15 +68,19 @@ struct BlockUniversalGemmAsBsCr
|
||||
static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN;
|
||||
static constexpr index_t KPerBlockPerIter = WarpGemm::kK;
|
||||
|
||||
// TODO: Should we have two policies? Interwave & Intrawave ??
|
||||
// Controls how many MAC clusters (MFMA blocks) we have per wave
|
||||
// I.e., if
|
||||
// InterWaveSchedulingMacClusters = 1;
|
||||
// KPerBlock == 32
|
||||
// WarpGemm::kK = 8
|
||||
// Then we would group all 4 WarpGemms into a single MAC cluster.
|
||||
// But if we would set InterWaveSchedulingMacClusters = 2, then we would
|
||||
// split those 4 warp gemms into two groups.
|
||||
static constexpr index_t InterWaveSchedulingMacClusters = 1;
|
||||
|
||||
// should be at least equal to: WarpGemm::Impl::kABKPerLane
|
||||
// and the question is how to assess upper limit or exact value?
|
||||
// TODO: Should we introduce AK1/BK1 parameters ?
|
||||
static constexpr index_t KPack = 8;
|
||||
static constexpr index_t KPerThread = KIterPerWarp * KPack;
|
||||
static constexpr index_t KRepeat = KPerThread / KPack;
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
|
||||
};
|
||||
|
||||
public:
|
||||
@@ -129,11 +133,12 @@ struct BlockUniversalGemmAsBsCr
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
|
||||
constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack);
|
||||
constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK;
|
||||
constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
|
||||
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
|
||||
sequence<KIterInterWave>,
|
||||
sequence<KIterInterwave>,
|
||||
sequence<KIterPerWarp>>;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
@@ -153,11 +158,12 @@ struct BlockUniversalGemmAsBsCr
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
|
||||
constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack);
|
||||
constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK;
|
||||
constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
|
||||
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
|
||||
sequence<KIterInterWave>,
|
||||
sequence<KIterInterwave>,
|
||||
sequence<KIterPerWarp>>;
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
@@ -371,11 +377,9 @@ struct BlockUniversalGemmAsBsCr
|
||||
static constexpr index_t KPerThread = GemmTraits::KPerThread;
|
||||
static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
|
||||
static constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, GemmTraits::KPack);
|
||||
// TODO: do we really need this?? Are there any cases when this would be >=1 ??
|
||||
// Would we need InterWaveSchedulingMacClusters > 1 ???
|
||||
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
|
||||
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
static constexpr index_t KInnerLoopIter = KPerInnerLoop / GemmTraits::KPack;
|
||||
static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
static constexpr auto ALdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
|
||||
Reference in New Issue
Block a user