mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
doc the kGroup definition
This commit is contained in:
@@ -58,11 +58,21 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
|
||||
static constexpr index_t KRepeat = KPerThread / KPack;
|
||||
static constexpr index_t KPerInnerLoop = KPack;
|
||||
static constexpr index_t KGroup =
|
||||
((MPerXDL == 16 && MPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) ||
|
||||
(MPerXDL == 32 && MPerXDL == 32 && xdlops_gemm.KPerXdlops == 64))
|
||||
? 2
|
||||
: 1;
|
||||
|
||||
static constexpr index_t KGroup = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<ComputeDataType>, f8_t>)
|
||||
// On gfx950, we have mfma that required 32 f8 elements as input,
|
||||
// splited into 2 groups of 16 f8 elements.
|
||||
// the 2 groups is not contiguous in the B preshuffed layout.
|
||||
// and we do not want it to be contiguous in the B preshuffled layout
|
||||
// because a memory instruction can only read 16 f8 elements at a time.
|
||||
return ((MPerXDL == 16 && MPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) ||
|
||||
(MPerXDL == 32 && MPerXDL == 32 && xdlops_gemm.KPerXdlops == 64))
|
||||
? 2
|
||||
: 1;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
@@ -171,15 +171,25 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
static constexpr auto mfma = MfmaSelector<ComputeTypeA,
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
static constexpr auto mfma = MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeA,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>{};
|
||||
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
|
||||
static constexpr index_t KGroup = mfma.selected_mfma.k_per_blk == 32 ? 2 : 1;
|
||||
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
|
||||
static constexpr index_t KGroup = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, f8_t>)
|
||||
// On gfx950, we have a mfma that required 32 f8 elements as input,
|
||||
// splited into 2 groups of 16 f8 elements.
|
||||
// the 2 groups is not contiguous in the B preshuffed layout.
|
||||
// and we do not want it to be contiguous in the B preshuffled layout
|
||||
// because a memory instruction can only read 16 f8 elements at a time.
|
||||
return mfma.selected_mfma.k_per_blk == 32 ? 2 : 1;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops();
|
||||
static constexpr index_t KPackPerGroup = KPack / KGroup;
|
||||
static constexpr index_t KRepeat = KPerBlock / KLane / KPackPerGroup;
|
||||
|
||||
@@ -175,7 +175,17 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle
|
||||
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
|
||||
static constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
|
||||
static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
|
||||
static constexpr index_t KGroup = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, f8_t>)
|
||||
// On gfx950, we have a mfma that required 32 f8 elements as input,
|
||||
// splited into 2 groups of 16 f8 elements.
|
||||
// the 2 groups is not contiguous in the B preshuffed layout.
|
||||
// and we do not want it to be contiguous in the B preshuffled layout
|
||||
// because a memory instruction can only read 16 f8 elements at a time.
|
||||
return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
static constexpr index_t KLane =
|
||||
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
|
||||
static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
|
||||
|
||||
@@ -189,14 +189,20 @@ struct GridwiseMoeGemm
|
||||
static constexpr index_t KLane =
|
||||
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
|
||||
|
||||
static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
|
||||
static constexpr index_t KRepeat = []() {
|
||||
static constexpr index_t KGroup = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, f8_t>)
|
||||
return KPerBlock / KLane / (KPack / KGroup);
|
||||
// On gfx950, we have a mfma that required 32 f8 elements as input,
|
||||
// splited into 2 groups of 16 f8 elements.
|
||||
// the 2 groups is not contiguous in the B preshuffed layout.
|
||||
// and we do not want it to be contiguous in the B preshuffled layout
|
||||
// because a memory instruction can only read 16 f8 elements at a time.
|
||||
return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
|
||||
else
|
||||
return KPerBlock / KLane / KPack;
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
|
||||
|
||||
static constexpr index_t NLane = NPerXdl;
|
||||
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
|
||||
// static constexpr index_t NumTokens = 1;
|
||||
|
||||
@@ -195,7 +195,17 @@ struct GridwiseMoeGemmBlockScale
|
||||
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
|
||||
static constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
|
||||
static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
|
||||
static constexpr index_t KGroup = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, f8_t>)
|
||||
// On gfx950, we have a mfma that required 32 f8 elements as input,
|
||||
// splited into 2 groups of 16 f8 elements.
|
||||
// the 2 groups is not contiguous in the B preshuffed layout.
|
||||
// and we do not want it to be contiguous in the B preshuffled layout
|
||||
// because a memory instruction can only read 16 f8 elements at a time.
|
||||
return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
static constexpr index_t KLane =
|
||||
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
|
||||
static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
|
||||
|
||||
Reference in New Issue
Block a user