support swiglu activation and use rcpf to accelerate silu

Feng Shijie
2025-08-26 12:32:29 +00:00
parent d05eed931d
commit 65b702454c
8 changed files with 376 additions and 350 deletions


@@ -77,11 +77,59 @@ enum class MoeFlatmmKind
kFFN_gemm2,
};
namespace moe {

struct MoeSilu
{
    template <typename T>
    CK_TILE_HOST_DEVICE T operator()(T gate, T linear = 1) const
    {
        ck_tile::element_wise::Silu{}(gate, gate);
        return gate * linear;
    }
};

struct Swiglu
{
    float alpha = 1.702f; // default value used in gpt-oss
    float limit = 7.0f;   // default value used in gpt-oss

    CK_TILE_HOST_DEVICE
    Swiglu() = default;

    CK_TILE_HOST_DEVICE
    Swiglu(float alpha_, float limit_) : alpha(alpha_), limit(limit_) {}

    template <typename T>
    CK_TILE_HOST_DEVICE T operator()(T gate, T linear) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                          std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                          std::is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");

        constexpr T one = type_convert<T>(1);
        gate   = gate < limit ? gate : limit;
        linear = linear < limit ? (linear > -limit ? linear : -limit) : limit;
        if constexpr(std::is_same_v<T, float>)
        {
            return gate * __builtin_amdgcn_rcpf(one + ck_tile::exp(alpha * -gate)) * (linear + 1);
        }
        else
        {
            return gate * (one / (one + ck_tile::exp(alpha * -gate))) * (linear + 1);
        }
    }
};

} // namespace moe
template <typename TilePartitioner_,
typename FlatmmPipeline_,
typename EpiloguePipeline_,
MoeFlatmmKind kind,
typename FusedActivation = element_wise::Silu>
typename FusedActivation = moe::MoeSilu>
struct MoeFlatmmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
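
For reference, the math encoded by moe::Swiglu above is out = gate * sigmoid(alpha * gate) * (linear + 1), with gate capped at limit from above and linear clamped to [-limit, limit]; the float specialization only replaces the exact divide with the hardware reciprocal __builtin_amdgcn_rcpf, trading a little precision for throughput. A minimal host-side reference sketch the fast path can be compared against (the swiglu_ref helper is illustrative and not part of this commit):

#include <algorithm>
#include <cmath>

// Plain-divide reference for the gpt-oss style SwiGLU computed by moe::Swiglu above.
// The defaults mirror the struct's alpha/limit members; the function name is hypothetical.
inline float swiglu_ref(float gate, float linear, float alpha = 1.702f, float limit = 7.0f)
{
    gate   = std::min(gate, limit);              // gate is only capped from above
    linear = std::clamp(linear, -limit, limit);  // linear is clamped to [-limit, limit]
    const float sig = 1.0f / (1.0f + std::exp(-alpha * gate));
    return gate * sig * (linear + 1.0f);
}
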
@@ -900,11 +948,9 @@ struct MoeFlatmmKernel
});
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
gate_tensor.get_thread_buffer().at(idx));
lds_tile[0].get_thread_buffer().at(idx) =
gate_tensor.get_thread_buffer().at(idx) *
up_tensor.get_thread_buffer().at(idx);
lds_tile[0].get_thread_buffer().at(idx) =
    ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
                   up_tensor.get_thread_buffer().at(idx));
});
}
else
@@ -937,8 +983,8 @@ struct MoeFlatmmKernel
if constexpr(IsInputGemm)
{
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
ActivationOp{}(lds_tile[0].get_thread_buffer().at(idx),
lds_tile[0].get_thread_buffer().at(idx));
lds_tile[0].get_thread_buffer().at(idx) =
ActivationOp{}(lds_tile[0].get_thread_buffer().at(idx));
});
}
}
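
Both epilogue paths above go through the same value-returning functor: the gated branch passes the gate and up accumulators as the two arguments, while this non-gated branch relies on MoeSilu's defaulted linear = 1 and passes a single value (moe::Swiglu has no default, so it always needs both operands). A hedged sketch of the two call shapes, with stand-in scalars instead of the kernel's thread buffers:

// Stand-in accumulators; in the kernel these come from gate_tensor / up_tensor / lds_tile.
float acc_gate = 0.7f;
float acc_up   = 1.3f;

using ActivationOp = moe::MoeSilu; // or moe::Swiglu for the gpt-oss variant

// Gated FFN epilogue: activation and elementwise multiply fused into one call.
float gated = ActivationOp{}(acc_gate, acc_up); // silu(acc_gate) * acc_up

// Non-gated path: MoeSilu's default linear = 1 reduces this to plain silu(acc_gate).
float plain = ActivationOp{}(acc_gate);
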
@@ -1022,11 +1068,9 @@ struct MoeFlatmmKernel
});
});
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
gate_tensor.get_thread_buffer().at(idx));
lds_tile[write_stage].get_thread_buffer().at(idx) =
gate_tensor.get_thread_buffer().at(idx) *
up_tensor.get_thread_buffer().at(idx);
lds_tile[write_stage].get_thread_buffer().at(idx) =
    ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
                   up_tensor.get_thread_buffer().at(idx));
});
}
else
@@ -1068,8 +1112,8 @@ struct MoeFlatmmKernel
if constexpr(IsInputGemm)
{
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
ActivationOp{}(lds_tile[write_stage].get_thread_buffer().at(idx),
lds_tile[write_stage].get_thread_buffer().at(idx));
lds_tile[write_stage].get_thread_buffer().at(idx) = ActivationOp{}(
lds_tile[write_stage].get_thread_buffer().at(idx));
});
}
}
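
With the FusedActivation template parameter, the activation is chosen at kernel instantiation time. A hedged sketch, assuming TilePartitioner, FlatmmPipeline, and EpiloguePipeline are whatever concrete policies the surrounding launcher already defines, and that the gate/up GEMM uses a kFFN_gemm1 kind (only kFFN_gemm2 appears in the hunk above, so that name is an assumption):

// Default epilogue: gated SiLU via moe::MoeSilu.
using SiluFfnKernel = MoeFlatmmKernel<TilePartitioner,
                                      FlatmmPipeline,
                                      EpiloguePipeline,
                                      MoeFlatmmKind::kFFN_gemm1>;

// gpt-oss style SwiGLU, with alpha = 1.702 and limit = 7 as the defaults baked into moe::Swiglu.
using SwigluFfnKernel = MoeFlatmmKernel<TilePartitioner,
                                        FlatmmPipeline,
                                        EpiloguePipeline,
                                        MoeFlatmmKind::kFFN_gemm1,
                                        moe::Swiglu>;
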