mirror of https://github.com/ROCm/composable_kernel.git

commit: support swiglu activation and use rcpf to accelerate silu
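What the diff below adds, in brief: a `moe` namespace with two value-returning fused-activation functors, `MoeSilu` (SiLU on the gate operand, multiplied by the linear/"up" operand) and `Swiglu` (the clamped SwiGLU variant used by gpt-oss), plus epilogue changes in `MoeFlatmmKernel` that call them, with the float SwiGLU path using the hardware reciprocal builtin `__builtin_amdgcn_rcpf`. As a reference for the math the two functors compute, here is a minimal host-side sketch; the `reference_*` helpers are my own stand-ins (not repo code), with `std::exp` in place of `ck_tile::exp`:

#include <algorithm>
#include <cmath>
#include <cstdio>

// silu(gate) * linear -- what moe::MoeSilu computes.
float reference_moe_silu(float gate, float linear = 1.0f)
{
    return gate / (1.0f + std::exp(-gate)) * linear;
}

// Clamped swiglu as in the diff: gate is capped above at `limit`, linear is
// clamped to [-limit, limit], and the (linear + 1) term matches the gpt-oss
// formulation.
float reference_swiglu(float gate, float linear, float alpha = 1.702f, float limit = 7.0f)
{
    gate   = std::min(gate, limit);
    linear = std::clamp(linear, -limit, limit);
    return gate / (1.0f + std::exp(alpha * -gate)) * (linear + 1.0f);
}

int main()
{
    // Large inputs exercise the clamps: gate -> 7, linear -> 7.
    std::printf("swiglu(10, 10) = %f\n", reference_swiglu(10.0f, 10.0f)); // ~55.9996
    std::printf("silu(0.5) * 2  = %f\n", reference_moe_silu(0.5f, 2.0f));
    return 0;
}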
@@ -77,11 +77,59 @@ enum class MoeFlatmmKind
     kFFN_gemm2,
 };
 
+namespace moe {
+
+struct MoeSilu
+{
+    template <typename T>
+    CK_TILE_HOST_DEVICE T operator()(T gate, T linear = 1) const
+    {
+        ck_tile::element_wise::Silu{}(gate, gate);
+        return gate * linear;
+    };
+};
+
+struct Swiglu
+{
+    float alpha = 1.702f; // default value used in gpt-oss
+    float limit = 7.0f;   // default value used in gpt-oss
+
+    CK_TILE_HOST_DEVICE
+    Swiglu() = default;
+    CK_TILE_HOST_DEVICE
+    Swiglu(float alpha_, float limit_) : alpha(alpha_), limit(limit_) {}
+
+    template <typename T>
+    CK_TILE_HOST_DEVICE T operator()(T gate, T linear) const
+    {
+        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
+                          std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
+                          std::is_same_v<T, int32_t>,
+                      "Data type is not supported by this operation!");
+
+        constexpr T one = type_convert<T>(1);
+
+        gate   = gate < limit ? gate : limit;
+        linear = linear < limit ? (linear > -limit ? linear : -limit) : limit;
+
+        if constexpr(std::is_same_v<T, float>)
+        {
+            return gate * __builtin_amdgcn_rcpf(one + ck_tile::exp(alpha * -gate)) * (linear + 1);
+        }
+        else
+        {
+            return gate * (one / (one + ck_tile::exp(alpha * -gate))) * (linear + 1);
+        }
+    }
+};
+
+} // namespace moe
+
 template <typename TilePartitioner_,
           typename FlatmmPipeline_,
           typename EpiloguePipeline_,
           MoeFlatmmKind kind,
-          typename FusedActivation = element_wise::Silu>
+          typename FusedActivation = moe::MoeSilu>
 struct MoeFlatmmKernel
 {
     using TilePartitioner = remove_cvref_t<TilePartitioner_>;
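Two things to note in this hunk. First, the default fused activation moves from `element_wise::Silu`, an out-parameter functor called as `op(y, x)`, to `moe::MoeSilu`, which returns its result and takes the gate and linear ("up") operands together. Second, the activation stays a template parameter, so callers can opt into SwiGLU. A hedged usage sketch of the new functors (it assumes the ck_tile header defining them is included; the function and variable names are mine):

// Assumes the header defining moe::MoeSilu / moe::Swiglu is included.
float activation_demo(float gate, float up)
{
    float y_silu   = moe::MoeSilu{}(gate, up);          // silu(gate) * up
    float y_silu1  = moe::MoeSilu{}(gate);              // linear defaults to 1 -> plain silu
    float y_swiglu = moe::Swiglu{}(gate, up);           // gpt-oss defaults: alpha 1.702, limit 7
    float y_custom = moe::Swiglu{1.0f, 6.0f}(gate, up); // explicit alpha / limit
    return y_silu + y_silu1 + y_swiglu + y_custom;      // keep the values live
}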
@@ -900,11 +948,9 @@ struct MoeFlatmmKernel
                 });
 
                 static_for<0, ActVectorSize, 1>{}([&](auto idx) {
-                    ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
-                                   gate_tensor.get_thread_buffer().at(idx));
                     lds_tile[0].get_thread_buffer().at(idx) =
-                        gate_tensor.get_thread_buffer().at(idx) *
-                        up_tensor.get_thread_buffer().at(idx);
+                        ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
+                                       up_tensor.get_thread_buffer().at(idx));
                 });
             }
             else
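Per element, this hunk collapses "activate the gate in place, then multiply by up" into one value-returning call whose result goes straight into the LDS tile. A scalar model of the change (`Act`, the raw pointers, and `fused_gate_up` are my stand-ins for the tile machinery in the diff):

#include <cstddef>

// Scalar model of the fused first-gemm epilogue; Act stands in for
// ActivationOp (moe::MoeSilu or moe::Swiglu).
template <typename Act>
void fused_gate_up(float* lds, const float* gate, const float* up, std::size_t n)
{
    Act act{};
    for(std::size_t i = 0; i < n; ++i)
        lds[i] = act(gate[i], up[i]); // was: act(g, g); then lds[i] = g * up[i];
}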
@@ -937,8 +983,8 @@ struct MoeFlatmmKernel
             if constexpr(IsInputGemm)
             {
                 static_for<0, ActVectorSize, 1>{}([&](auto idx) {
-                    ActivationOp{}(lds_tile[0].get_thread_buffer().at(idx),
-                                   lds_tile[0].get_thread_buffer().at(idx));
+                    lds_tile[0].get_thread_buffer().at(idx) =
+                        ActivationOp{}(lds_tile[0].get_thread_buffer().at(idx));
                 });
             }
         }
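This path has no separate up operand; the activation is applied in place over the LDS tile. With the value-returning interface that becomes an assignment, and with `moe::MoeSilu` the single-operand call uses the defaulted `linear = 1`, i.e. plain SiLU. (Note that `moe::Swiglu::operator()` requires two operands, so this single-operand path appears to presume a SiLU-style activation.) A minimal in-place sketch (`apply_inplace` is mine, not repo code):

#include <cstddef>

// Scalar model of the in-place update in this hunk: the functor's return
// value is written back to the same slot instead of being mutated through
// an out-parameter. Act stands in for ActivationOp.
template <typename Act>
void apply_inplace(float* buf, std::size_t n)
{
    Act act{};
    for(std::size_t i = 0; i < n; ++i)
        buf[i] = act(buf[i]); // with moe::MoeSilu: buf[i] = silu(buf[i]) * 1
}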
@@ -1022,11 +1068,9 @@ struct MoeFlatmmKernel
                     });
                 });
                 static_for<0, ActVectorSize, 1>{}([&](auto idx) {
-                    ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
-                                   gate_tensor.get_thread_buffer().at(idx));
                     lds_tile[write_stage].get_thread_buffer().at(idx) =
-                        gate_tensor.get_thread_buffer().at(idx) *
-                        up_tensor.get_thread_buffer().at(idx);
+                        ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
+                                       up_tensor.get_thread_buffer().at(idx));
                 });
             }
             else
@@ -1068,8 +1112,8 @@ struct MoeFlatmmKernel
             if constexpr(IsInputGemm)
             {
                 static_for<0, ActVectorSize, 1>{}([&](auto idx) {
-                    ActivationOp{}(lds_tile[write_stage].get_thread_buffer().at(idx),
-                                   lds_tile[write_stage].get_thread_buffer().at(idx));
+                    lds_tile[write_stage].get_thread_buffer().at(idx) = ActivationOp{}(
+                        lds_tile[write_stage].get_thread_buffer().at(idx));
                 });
             }
         }
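On the "rcpf" half of the commit message: in the float path, `moe::Swiglu` evaluates the sigmoid denominator with `__builtin_amdgcn_rcpf`, which lowers to the GPU's approximate-reciprocal instruction (V_RCP_F32 on GCN/CDNA) rather than the longer refined-divide sequence the compiler would otherwise emit, trading a little accuracy for speed. A self-contained HIP sketch (my own harness, not part of the commit; the grid sizes and input range are arbitrary) that measures the relative error of that shortcut over the clamped gate range:

#include <hip/hip_runtime.h>
#include <cstdio>
#include <vector>

// Measure the relative error of the hardware reciprocal against a full
// divide for the sigmoid denominator used in moe::Swiglu's float path.
__global__ void rcp_error(const float* x, float* err, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i >= n)
        return;
    float d      = 1.0f + __expf(-1.702f * x[i]); // denominator, always >= 1
    float exact  = 1.0f / d;
    float approx = __builtin_amdgcn_rcpf(d);
    err[i]       = fabsf(approx - exact) / exact;
}

int main()
{
    const int n = 1024;
    std::vector<float> h_x(n), h_err(n);
    for(int i = 0; i < n; ++i)
        h_x[i] = -7.0f + 14.0f * i / (n - 1); // span the clamped gate range

    float *d_x, *d_err;
    hipMalloc(&d_x, n * sizeof(float));
    hipMalloc(&d_err, n * sizeof(float));
    hipMemcpy(d_x, h_x.data(), n * sizeof(float), hipMemcpyHostToDevice);

    rcp_error<<<(n + 255) / 256, 256>>>(d_x, d_err, n);

    hipMemcpy(h_err.data(), d_err, n * sizeof(float), hipMemcpyDeviceToHost);
    float max_err = 0.0f;
    for(float e : h_err)
        max_err = e > max_err ? e : max_err;
    std::printf("max relative error vs divide: %g\n", max_err);

    hipFree(d_x);
    hipFree(d_err);
    return 0;
}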