mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
added moe interleaving pipeline (#1712)
* added moe interleaving pipeline * remove redundant code * formater --------- Co-authored-by: root <root@hjbog-srdc-14.amd.com>
This commit is contained in:
@@ -810,21 +810,46 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetUK_1()
|
||||
{
|
||||
using S_ = typename Problem::BlockShape;
|
||||
using T_ = typename Problem::Traits;
|
||||
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
|
||||
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
|
||||
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
|
||||
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
|
||||
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
|
||||
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
|
||||
T_::PipeInterleave == false)
|
||||
{
|
||||
return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16{};
|
||||
// return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
|
||||
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
|
||||
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
|
||||
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
|
||||
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
|
||||
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
|
||||
T_::PipeInterleave == false)
|
||||
{
|
||||
return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
|
||||
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
|
||||
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
|
||||
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
|
||||
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
|
||||
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
|
||||
T_::PipeInterleave == true)
|
||||
{
|
||||
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
|
||||
return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
|
||||
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
|
||||
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
|
||||
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
|
||||
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
|
||||
T_::PipeInterleave == true)
|
||||
{
|
||||
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
|
||||
return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -22,7 +22,8 @@ template <bool IsGateOnly_,
|
||||
FusedMoeGemmWeightPermuteEnum PermuteEnum_ =
|
||||
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten,
|
||||
bool PadHiddenSize_ = false,
|
||||
bool PadIntermediateSize_ = false>
|
||||
bool PadIntermediateSize_ = false,
|
||||
bool PipeInterleave_ = true>
|
||||
struct FusedMoeGemmTraits
|
||||
{
|
||||
// Gate+Up or Gate only
|
||||
@@ -32,6 +33,7 @@ struct FusedMoeGemmTraits
|
||||
static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_;
|
||||
static constexpr bool PadHiddenSize = PadHiddenSize_;
|
||||
static constexpr bool PadIntermediateSize = PadIntermediateSize_;
|
||||
static constexpr bool PipeInterleave = PipeInterleave_;
|
||||
};
|
||||
|
||||
// Note: this need to be a bit mask
|
||||
|
||||
Reference in New Issue
Block a user