support swiglu activaion and use rcpf to accelerate silu

This commit is contained in:
Feng Shijie
2025-08-26 12:32:29 +00:00
parent d05eed931d
commit 65b702454c
8 changed files with 376 additions and 350 deletions

View File

@@ -205,14 +205,13 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
: col * strideC + scatter_token_id;
if constexpr(MoeGemmKind < 2)
{
AccDataType acc_gate = ActivationOp{}(acc);
C[c_index] =
ck_tile::type_convert<CDataType>(MoeGemmKind == 1 ? acc_gate * acc_up : acc_gate);
C[c_index] = ck_tile::type_convert<CDataType>(
ActivationOp{}(acc, MoeGemmKind == 1 ? acc_up : 1));
}
else
{
CDataType res =
ck_tile::type_convert<CDataType>(ActivationOp{}(acc * expert_weight_ptr[row]));
// moe gemm2 don't use activation.
CDataType res = ck_tile::type_convert<CDataType>(acc * expert_weight_ptr[row]);
using ResV2Type = std::conditional_t<std::is_same_v<CDataType, ck_tile::half_t>,
ck_tile::fp16x2_t,
ck_tile::bf16x2_t>;

View File

@@ -863,7 +863,14 @@ struct Silu
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck_tile::exp(-x)));
if constexpr(std::is_same_v<T, float>)
{
y = x * __builtin_amdgcn_rcpf(one + ck_tile::exp(-x));
}
else
{
y = x * (one / (one + ck_tile::exp(-x)));
}
};
template <>
@@ -1218,7 +1225,7 @@ struct Swish
struct SoftRelu
{
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
SoftRelu(float alpha = 1.f) : alpha_(alpha) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1237,7 +1244,7 @@ struct SoftRelu
struct Power
{
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma){};
: alpha_(alpha), beta_(beta), gamma_(gamma) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1259,7 +1266,7 @@ struct Power
struct ClippedRelu
{
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1278,7 +1285,7 @@ struct ClippedRelu
struct LeakyRelu
{
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1295,7 +1302,7 @@ struct LeakyRelu
struct Elu
{
Elu(float alpha = 1.f) : alpha_(alpha){};
Elu(float alpha = 1.f) : alpha_(alpha) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1312,7 +1319,7 @@ struct Elu
struct Logistic
{
Logistic(float alpha = 1.f) : alpha_(alpha){};
Logistic(float alpha = 1.f) : alpha_(alpha) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const

View File

@@ -24,17 +24,21 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_num)
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool, TailNumber tail_num)
{
if (TailNumber::Even == tail_num)
if(TailNumber::Even == tail_num)
{
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Even>{});
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Even>{});
}
else if (TailNumber::Odd == tail_num)
else if(TailNumber::Odd == tail_num)
{
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Odd>{});
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Odd>{});
}
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
// return run_func(bool_constant<true>{}, integral_constant<TailNumber,
// TailNumber::Empty>{});
}
};
@@ -52,16 +56,17 @@ struct FlatmmPipelineAGmemBGmemCRegV1
using BlockFlatmm =
remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
static constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
static constexpr auto config =
BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
@@ -105,64 +110,67 @@ struct FlatmmPipelineAGmemBGmemCRegV1
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) ? DsReadPreload: MIterPerWarp * KIterPerWarp;
static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
? DsReadPreload
: MIterPerWarp * KIterPerWarp;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
/*
defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) // mi300 fp8 16c 0.5*K1
defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) // mi300 fp8 32c 0.5*K1
defined(USING_MFMA_16x16x16) && defined(ENABLE_FP16) // mi300 fp16 16c 0.5*K1
defined(USING_MFMA_32x32x8) && defined(ENABLE_FP16) // mi300 fp16 32c 0.5*K1
/*
defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) // mi300 fp8 16c 0.5*K1
defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) // mi300 fp8 32c 0.5*K1
defined(USING_MFMA_16x16x16) && defined(ENABLE_FP16) // mi300 fp16 16c 0.5*K1
defined(USING_MFMA_32x32x8) && defined(ENABLE_FP16) // mi300 fp16 32c 0.5*K1
defined(USING_MFMA_16x16x128) && defined(ENABLE_FP8) // mi350 fp8 32c 2*K1
defined(USING_MFMA_32x32x64) && defined(ENABLE_FP8) // mi350 fp8 64c 2*K1
defined(USING_MFMA_16x16x32) && defined(ENABLE_FP16) // mi350 fp16 16c 1*K1
defined(USING_MFMA_32x32x16) && defined(ENABLE_FP16) // mi350 fp16 32c 1*K1
defined(USING_MFMA_16x16x128) && defined(ENABLE_FP8) // mi350 fp8 32c 2*K1
defined(USING_MFMA_32x32x64) && defined(ENABLE_FP8) // mi350 fp8 64c 2*K1
defined(USING_MFMA_16x16x32) && defined(ENABLE_FP16) // mi350 fp16 16c 1*K1
defined(USING_MFMA_32x32x16) && defined(ENABLE_FP16) // mi350 fp16 32c 1*K1
defined(USING_MFMA_16x16x128) && defined(ENABLE_FP4) // mi350 fp4 16c 1*K1
defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
*/
defined(USING_MFMA_16x16x128) && defined(ENABLE_FP4) // mi350 fp4 16c 1*K1
defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
*/
// #if (defined(USING_MFMA_16x16x32_F8) || \
// defined(USING_MFMA_32x32x16_F8) || \
// defined(USING_MFMA_16x16x16_F16) || \
// defined(USING_MFMA_32x32x8_F16)) // K1 per Mfma = 0.5
// static constexpr auto mfma_per_wg = 2;
// static constexpr auto dsread_per_wg = 1;
// #elif (defined(USING_MFMA_16x16x32_F16) || \
// defined(USING_MFMA_32x32x16_F16) || \
// defined(USING_MFMA_16x16x128_F4) || \
// defined(USING_MFMA_32x32x64_F4)) // K1 per Mfma = 1
// static constexpr auto mfma_per_wg = 1;
// static constexpr auto dsread_per_wg = 1;
// #elif (defined(USING_MFMA_16x16x128_F8) || \
// defined(USING_MFMA_32x32x64_F8)) // K1 per Mfma = 2
// static constexpr auto mfma_per_wg = 1;
// static constexpr auto dsread_per_wg = 2;
// #endif
#ifdef __gfx942__
static constexpr index_t mfma_per_wg = 2;
#else
static constexpr index_t mfma_per_wg = 1;
#endif
static constexpr index_t dsread_per_wg = WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize;
// #if (defined(USING_MFMA_16x16x32_F8) || \
// defined(USING_MFMA_32x32x16_F8) || \
// defined(USING_MFMA_16x16x16_F16) || \
// defined(USING_MFMA_32x32x8_F16)) // K1 per Mfma = 0.5
// static constexpr auto mfma_per_wg = 2;
// static constexpr auto dsread_per_wg = 1;
// #elif (defined(USING_MFMA_16x16x32_F16) || \
// defined(USING_MFMA_32x32x16_F16) || \
// defined(USING_MFMA_16x16x128_F4) || \
// defined(USING_MFMA_32x32x64_F4)) // K1 per Mfma = 1
// static constexpr auto mfma_per_wg = 1;
// static constexpr auto dsread_per_wg = 1;
// #elif (defined(USING_MFMA_16x16x128_F8) || \
// defined(USING_MFMA_32x32x64_F8)) // K1 per Mfma = 2
// static constexpr auto mfma_per_wg = 1;
// static constexpr auto dsread_per_wg = 2;
// #endif
#ifdef __gfx942__
static constexpr index_t mfma_per_wg = 2;
#else
static constexpr index_t mfma_per_wg = 1;
#endif
static constexpr index_t dsread_per_wg =
WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize;
static_assert((WG::kM * WG::kK * sizeof(ADataType) / WaveSize) % Problem::VectorLoadSize == 0);
static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp;
static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp;
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
static constexpr index_t Aload_num_perK = dswrite_num_perK;
static constexpr index_t Aload_rep = dswrite_rep;
static constexpr index_t Aload_rep = dswrite_rep;
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / K1 / WaveSize;
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg;
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -185,25 +193,25 @@ struct FlatmmPipelineAGmemBGmemCRegV1
return PipelinePolicy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr auto SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
CK_TILE_HOST_DEVICE static constexpr auto
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
{
// Init inst order
index_t max_data_inst =
dsread_perM > load_perM
? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM)
: (load_perM > dswrite_perM ? load_perM : dswrite_perM);
index_t max_data_inst = dsread_perM > load_perM
? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM)
: (load_perM > dswrite_perM ? load_perM : dswrite_perM);
index_t sum_data_inst = dsread_perM + load_perM + dswrite_perM;
index_t round_data_inst = (sum_data_inst + mfma_perM_perK - 1) / mfma_perM_perK;
index_t inst_order[NIterPerWarp * 10];
#pragma unroll
#pragma unroll
for(int idx = 0; idx < NIterPerWarp * 10; idx++)
{
inst_order[idx] = 0;
}
index_t index = 0;
#pragma unroll
#pragma unroll
for(int j = 0; j < max_data_inst; j++)
{
if(dswrite_perM > j)
@@ -223,8 +231,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1
}
}
// Schedule IGLP
#pragma unroll
// Schedule IGLP
#pragma unroll
for(int j = 0; j < mfma_perM_perK; j++)
{
index_t inst_idx = 0;
@@ -236,10 +244,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
inst_idx = mfma_perM_perK - 1;
else
inst_idx = mfma_perM_perK - j;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
#pragma unroll
#pragma unroll
for(int r = 0; r < round_data_inst; r++)
{
if(r % 2 == 0)
@@ -284,84 +292,84 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// -1 M6N1: 58 1 - - -
// -1 M6N2: 59 - - 7 -
// -1 M6N3: 60 2 - - -
// -1 M7N0: 61 - - - -
// -1 M7N1: 62 3 - - -
// -1 M7N2: 63 - - 8 -
// -1 M7N3: 64 4 - - -
// 0 M0N0K0: 1 - - - 1
// 0 M0N1: 2 5 - - -
// 0 M0N2: 3 - - - 2
// 0 M0N3: 4 6 - - -
// 0 M1N0: 5 - - - 3
// 0 M1N1: 6 7 - - -
// 0 M1N2: 7 - - - 4
// 0 M1N3: 8 8 - - -
// 0 M2N0: 9 - - - 5
// 0 M2N1: 10 9 - - -
// 0 M2N2: 11 - - - 6
// 0 M2N3: 12 10 - - -
// 0 M3N0: 13 - 1 - 7
// 0 M3N1: 14 11 - - -
// 0 M3N2: 15 - - - 8
// -1 M7N0: 61 - - - -
// -1 M7N1: 62 3 - - -
// -1 M7N2: 63 - - 8 -
// -1 M7N3: 64 4 - - -
// 0 M0N0K0: 1 - - - 1
// 0 M0N1: 2 5 - - -
// 0 M0N2: 3 - - - 2
// 0 M0N3: 4 6 - - -
// 0 M1N0: 5 - - - 3
// 0 M1N1: 6 7 - - -
// 0 M1N2: 7 - - - 4
// 0 M1N3: 8 8 - - -
// 0 M2N0: 9 - - - 5
// 0 M2N1: 10 9 - - -
// 0 M2N2: 11 - - - 6
// 0 M2N3: 12 10 - - -
// 0 M3N0: 13 - 1 - 7
// 0 M3N1: 14 11 - - -
// 0 M3N2: 15 - - - 8
// 0 M3N3: 16 12 - - -
// 0 M4N0: 17 - 2 - -
// 0 M4N1: 18 13 - - -
// 0 M4N2: 19 - - 1 -
// 0 M4N0: 17 - 2 - -
// 0 M4N1: 18 13 - - -
// 0 M4N2: 19 - - 1 -
// 0 M4N3: 20 14 - - -
// 0 M5N0: 21 - 3 - -
// 0 M5N1: 22 15 - - -
// 0 M5N2: 23 - - 2 -
// 0 M5N0: 21 - 3 - -
// 0 M5N1: 22 15 - - -
// 0 M5N2: 23 - - 2 -
// 0 M5N3: 24 16 - - -
// 0 M6N0: 25 - 4 - -
// 0 M6N1: 26 17 - - -
// 0 M6N2: 27 - - 3 -
// 0 M6N0: 25 - 4 - -
// 0 M6N1: 26 17 - - -
// 0 M6N2: 27 - - 3 -
// 0 M6N3: 28 18 - - -
// 0 M7N0: 29 - - - -
// 0 M7N1: 30 19 - - -
// 0 M7N2: 31 - - 4 -
// 0 M7N0: 29 - - - -
// 0 M7N1: 30 19 - - -
// 0 M7N2: 31 - - 4 -
// 0 M7N3: 32 20 - - -
// 0 M0N0K1: 33 - - - 9
// 0 M0N1: 34 21 - - -
// 0 M0N2: 35 - - - 10
// 0 M0N3: 36 22 - - -
// 0 M1N0: 37 - - - 11
// 0 M1N1: 38 23 - - -
// 0 M1N2: 39 - - - 12
// 0 M0N0K1: 33 - - - 9
// 0 M0N1: 34 21 - - -
// 0 M0N2: 35 - - - 10
// 0 M0N3: 36 22 - - -
// 0 M1N0: 37 - - - 11
// 0 M1N1: 38 23 - - -
// 0 M1N2: 39 - - - 12
// 0 M1N3: 40 24 - - -
// 0 M2N0: 41 - - - 13
// 0 M2N1: 42 25 - - -
// 0 M2N2: 43 - - - 14
// 0 M2N3: 44 26 - - -
// 0 M3N0: 45 - 5 - 15
// 0 M3N1: 46 27 - - -
// 0 M3N2: 47 - - - 16
// 0 M2N0: 41 - - - 13
// 0 M2N1: 42 25 - - -
// 0 M2N2: 43 - - - 14
// 0 M2N3: 44 26 - - -
// 0 M3N0: 45 - 5 - 15
// 0 M3N1: 46 27 - - -
// 0 M3N2: 47 - - - 16
// 0 M3N3: 48 28 - - -
// 0 M4N0: 49 - 6 - -
// 0 M4N1: 50 29 - - -
// 0 M4N2: 51 - - 5 -
// 0 M4N0: 49 - 6 - -
// 0 M4N1: 50 29 - - -
// 0 M4N2: 51 - - 5 -
// 0 M4N3: 52 30 - - -
// 0 M5N0: 53 - 7 - -
// 0 M5N1: 54 31 - - -
// 0 M5N2: 55 - - 6 -
// 0 M5N0: 53 - 7 - -
// 0 M5N1: 54 31 - - -
// 0 M5N2: 55 - - 6 -
// 0 M5N3: 56 32 - - -
// 0 M6N0: 57 - 8 - -
// 0 M6N1: 58 1 - - -
// 0 M6N2: 59 - - 7 -
// 0 M6N0: 57 - 8 - -
// 0 M6N1: 58 1 - - -
// 0 M6N2: 59 - - 7 -
// 0 M6N3: 60 2 - - -
// 0 M7N0: 61 - - - -
// 0 M7N1: 62 3 - - -
// 0 M7N2: 63 - - 8 -
// 0 M7N0: 61 - - - -
// 0 M7N1: 62 3 - - -
// 0 M7N2: 63 - - 8 -
// 0 M7N3: 64 4 - - -
#pragma unroll
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{
#pragma unroll
#pragma unroll
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{
index_t dsread_perM = 0;
index_t dsread_perM = 0;
index_t dswrite_perM = 0;
index_t load_perM = 0;
index_t load_perM = 0;
// Calculate ds_read number per M
dsread_perM = dsread_per_wg;
@@ -380,10 +388,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
}
else
{
dswrite_perM =
(dswrite_num_perK - (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
? dswrite_rep
: 0;
dswrite_perM = (dswrite_num_perK -
(MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
? dswrite_rep
: 0;
}
// Add ds write when ds write data > needed
if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
@@ -397,13 +405,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1
{
load_perM =
((Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep
: 0) +
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep : 0);
: 0) +
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
: 0);
}
else
{
load_perM =
(Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep : 0;
load_perM = (Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0
? Aload_rep
: 0;
}
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
}
@@ -413,18 +423,18 @@ struct FlatmmPipelineAGmemBGmemCRegV1
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_barrier(0);
}
CK_TILE_HOST_DEVICE static constexpr auto Last2ndHotLoopScheduler()
{
#pragma unroll
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{
#pragma unroll
#pragma unroll
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{
index_t dsread_perM = 0;
index_t dsread_perM = 0;
index_t dswrite_perM = 0;
index_t load_perM = 0;
index_t load_perM = 0;
// Calculate ds_read number per M
dsread_perM = dsread_per_wg;
@@ -443,10 +453,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
}
else
{
dswrite_perM =
(dswrite_num_perK - (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
? dswrite_rep
: 0;
dswrite_perM = (dswrite_num_perK -
(MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
? dswrite_rep
: 0;
}
// Add ds write when ds write data > needed
if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
@@ -459,7 +469,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1
if(mIter < HalfMIter)
{
load_perM =
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep : 0);
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
: 0);
}
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
}
@@ -468,20 +479,20 @@ struct FlatmmPipelineAGmemBGmemCRegV1
}
CK_TILE_HOST_DEVICE static constexpr auto LastHotLoopScheduler()
{
#pragma unroll
{
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{
#pragma unroll
#pragma unroll
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{
index_t dsread_perM = 0;
index_t dsread_perM = 0;
index_t dswrite_perM = 0;
index_t load_perM = 0;
index_t load_perM = 0;
// Calculate ds_read number per M
if ((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
dsread_perM = dsread_per_wg;
if((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
dsread_perM = dsread_per_wg;
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
}
@@ -507,7 +518,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
"wrong!");
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iMWarp = get_warp_id() / NWarp;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
@@ -517,7 +528,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
__builtin_amdgcn_sched_barrier(0);
// A tile in LDS
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
@@ -525,8 +536,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block_ping = make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
auto a_lds_block_pong = make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
auto a_lds_block_ping =
make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
auto a_lds_block_pong =
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
@@ -543,22 +556,22 @@ struct FlatmmPipelineAGmemBGmemCRegV1
auto a_copy_lds_window_pong =
make_tile_window(a_lds_block_pong,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
// ping-pong window for A LDS
auto a_warp_window_ping_tmp = make_tile_window(
a_lds_block_ping,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
auto a_warp_window_ping_tmp =
make_tile_window(a_lds_block_ping,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
auto a_warp_window_pong_tmp = make_tile_window(
a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
auto a_warp_window_pong_tmp =
make_tile_window(a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
@@ -569,7 +582,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows_pong;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
@@ -619,7 +632,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1
NIterPerWarp>
b_warp_tensor_pong;
// HEAD
// Prefetch A0
auto a_block_tile = load_tile(a_copy_dram_window);
@@ -632,7 +644,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
@@ -640,19 +652,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// Prefill A0
// if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
// {
// auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
// PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
// shuffle_tile(a_shuffle_tmp, a_block_tile);
// const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
// store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
// }
// else
// {
// store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func, a_block_tile));
// }
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
__builtin_amdgcn_sched_barrier(0);
@@ -668,12 +667,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1
block_sync_lds();
// preload A00,A10... from lds
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))), m_preload> a_warp_tensor;
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
m_preload>
a_warp_tensor;
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0);
@@ -687,7 +689,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
@@ -701,7 +703,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
@@ -709,14 +711,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
@@ -724,14 +728,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
//barrier
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
@@ -745,10 +751,11 @@ struct FlatmmPipelineAGmemBGmemCRegV1
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
HotLoopScheduler();
// Next K
// prefetch B(2i+2)
@@ -757,12 +764,12 @@ struct FlatmmPipelineAGmemBGmemCRegV1
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// Prefill A(2i+2)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
@@ -782,10 +789,12 @@ struct FlatmmPipelineAGmemBGmemCRegV1
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), b_warp_tensor_pong(nIter)(kIter));
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
@@ -793,14 +802,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) = load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
//barrier
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
@@ -814,7 +825,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
HotLoopScheduler();
@@ -830,7 +842,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
@@ -847,14 +859,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
@@ -862,14 +876,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
//barrier
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
@@ -880,11 +896,12 @@ struct FlatmmPipelineAGmemBGmemCRegV1
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
Last2ndHotLoopScheduler();
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
@@ -892,27 +909,31 @@ struct FlatmmPipelineAGmemBGmemCRegV1
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), b_warp_tensor_pong(nIter)(kIter));
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) = load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
//barrier
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
@@ -930,14 +951,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
@@ -945,14 +968,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
//barrier
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
@@ -974,7 +999,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
{
return operator()(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
[](const ADataType & a) { return a; },
b_flat_dram_block_window_tmp,
num_loop,
p_smem_ping,

View File

@@ -77,11 +77,59 @@ enum class MoeFlatmmKind
kFFN_gemm2,
};
namespace moe {
struct MoeSilu
{
template <typename T>
CK_TILE_HOST_DEVICE T operator()(T gate, T linear = 1) const
{
ck_tile::element_wise::Silu{}(gate, gate);
return gate * linear;
};
};
struct Swiglu
{
float alpha = 1.702f; // default value used in gpt-oss
float limit = 7.0f; // default value used in gpt-oss
CK_TILE_HOST_DEVICE
Swiglu() = default;
CK_TILE_HOST_DEVICE
Swiglu(float alpha_, float limit_) : alpha(alpha_), limit(limit_) {}
template <typename T>
CK_TILE_HOST_DEVICE T operator()(T gate, T linear) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
gate = gate < limit ? gate : limit;
linear = linear < limit ? (linear > -limit ? linear : -limit) : limit;
if constexpr(std::is_same_v<T, float>)
{
return gate * __builtin_amdgcn_rcpf(one + ck_tile::exp(alpha * -gate)) * (linear + 1);
}
else
{
return gate * (one / (one + ck_tile::exp(alpha * -gate))) * (linear + 1);
}
}
};
} // namespace moe
template <typename TilePartitioner_,
typename FlatmmPipeline_,
typename EpiloguePipeline_,
MoeFlatmmKind kind,
typename FusedActivation = element_wise::Silu>
typename FusedActivation = moe::MoeSilu>
struct MoeFlatmmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
@@ -900,11 +948,9 @@ struct MoeFlatmmKernel
});
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
gate_tensor.get_thread_buffer().at(idx));
lds_tile[0].get_thread_buffer().at(idx) =
gate_tensor.get_thread_buffer().at(idx) *
up_tensor.get_thread_buffer().at(idx);
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
up_tensor.get_thread_buffer().at(idx));
});
}
else
@@ -937,8 +983,8 @@ struct MoeFlatmmKernel
if constexpr(IsInputGemm)
{
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
ActivationOp{}(lds_tile[0].get_thread_buffer().at(idx),
lds_tile[0].get_thread_buffer().at(idx));
lds_tile[0].get_thread_buffer().at(idx) =
ActivationOp{}(lds_tile[0].get_thread_buffer().at(idx));
});
}
}
@@ -1022,11 +1068,9 @@ struct MoeFlatmmKernel
});
});
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
gate_tensor.get_thread_buffer().at(idx));
lds_tile[write_stage].get_thread_buffer().at(idx) =
gate_tensor.get_thread_buffer().at(idx) *
up_tensor.get_thread_buffer().at(idx);
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
up_tensor.get_thread_buffer().at(idx));
});
}
else
@@ -1068,8 +1112,8 @@ struct MoeFlatmmKernel
if constexpr(IsInputGemm)
{
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
ActivationOp{}(lds_tile[write_stage].get_thread_buffer().at(idx),
lds_tile[write_stage].get_thread_buffer().at(idx));
lds_tile[write_stage].get_thread_buffer().at(idx) = ActivationOp{}(
lds_tile[write_stage].get_thread_buffer().at(idx));
});
}
}