mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
Add bias for f16xf4 moe_flatmm
This commit is contained in:
@@ -14,7 +14,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
class ExpertBias = FlatmmScalePointer<-1>>
|
||||
struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
|
||||
{
|
||||
ck_tile::index_t NumTokens;
|
||||
@@ -24,6 +26,7 @@ struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
|
||||
const ck_tile::index_t* p_sorted_expert_ids;
|
||||
const ck_tile::index_t* p_max_token_id;
|
||||
const void* p_sorted_expert_weights;
|
||||
ExpertBias exp_bias;
|
||||
|
||||
CK_TILE_HOST MoeFlatmmHostArgs() noexcept = default;
|
||||
CK_TILE_HOST MoeFlatmmHostArgs(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
@@ -43,8 +46,9 @@ struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
|
||||
ck_tile::index_t stride_A_,
|
||||
ck_tile::index_t stride_B_,
|
||||
ck_tile::index_t stride_C_,
|
||||
ScaleM scale_m_ = {},
|
||||
ScaleN scale_n_ = {})
|
||||
ScaleM scale_m_ = {},
|
||||
ScaleN scale_n_ = {},
|
||||
ExpertBias exp_bias_ = {})
|
||||
: ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>(a_ptr_,
|
||||
b_ptr_,
|
||||
{}, // d_ptr_array
|
||||
@@ -65,7 +69,8 @@ struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
|
||||
p_sorted_token_ids(p_sorted_token_ids_),
|
||||
p_sorted_expert_ids(p_sorted_expert_ids_),
|
||||
p_max_token_id(p_max_token_id_),
|
||||
p_sorted_expert_weights(p_sorted_expert_weights_)
|
||||
p_sorted_expert_weights(p_sorted_expert_weights_),
|
||||
exp_bias(exp_bias_)
|
||||
{
|
||||
}
|
||||
};
|
||||
@@ -91,13 +96,14 @@ struct MoeSilu
|
||||
|
||||
struct Swiglu
|
||||
{
|
||||
float alpha = 1.702f; // default value used in gpt-oss
|
||||
float limit = 7.0f; // default value used in gpt-oss
|
||||
const float alpha;
|
||||
const float limit;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
Swiglu() = default;
|
||||
CK_TILE_HOST_DEVICE
|
||||
Swiglu(float alpha_, float limit_) : alpha(alpha_), limit(limit_) {}
|
||||
Swiglu(float alpha_ = 1.702f, float limit_ = 7.0f) // use value in gpt-oss as default
|
||||
: alpha(alpha_), limit(limit_)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE T operator()(T gate, T linear) const
|
||||
@@ -190,7 +196,9 @@ struct MoeFlatmmKernel
|
||||
|
||||
static constexpr int WeightPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
class ExpertBias = FlatmmScalePointer<-1>>
|
||||
struct MoeFlatmmKernelArgs
|
||||
{
|
||||
const ck_tile::index_t* p_sorted_token_ids;
|
||||
@@ -211,30 +219,34 @@ struct MoeFlatmmKernel
|
||||
ck_tile::index_t k_batch;
|
||||
ScaleM scale_m;
|
||||
ScaleN scale_n;
|
||||
ExpertBias exp_bias;
|
||||
};
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
class ExpertBias = FlatmmScalePointer<-1>>
|
||||
CK_TILE_HOST static constexpr auto
|
||||
MakeKernelArgs(const MoeFlatmmHostArgs<ScaleM, ScaleN>& hostArgs)
|
||||
MakeKernelArgs(const MoeFlatmmHostArgs<ScaleM, ScaleN, ExpertBias>& hostArgs)
|
||||
{
|
||||
return MoeFlatmmKernelArgs<ScaleM, ScaleN>{hostArgs.p_sorted_token_ids,
|
||||
hostArgs.p_sorted_expert_ids,
|
||||
hostArgs.p_max_token_id,
|
||||
hostArgs.p_sorted_expert_weights,
|
||||
hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.NumTokens,
|
||||
hostArgs.TopK,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_C,
|
||||
hostArgs.k_batch,
|
||||
hostArgs.scale_m,
|
||||
hostArgs.scale_n};
|
||||
return MoeFlatmmKernelArgs<ScaleM, ScaleN, ExpertBias>{hostArgs.p_sorted_token_ids,
|
||||
hostArgs.p_sorted_expert_ids,
|
||||
hostArgs.p_max_token_id,
|
||||
hostArgs.p_sorted_expert_weights,
|
||||
hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.NumTokens,
|
||||
hostArgs.TopK,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_C,
|
||||
hostArgs.k_batch,
|
||||
hostArgs.scale_m,
|
||||
hostArgs.scale_n,
|
||||
hostArgs.exp_bias};
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
@@ -249,8 +261,8 @@ struct MoeFlatmmKernel
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
}
|
||||
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
|
||||
static constexpr auto GridSize(const MoeFlatmmKernelArgs<ScaleM, ScaleN>& kargs)
|
||||
template <class MoeFlatmmKernelArgs>
|
||||
static constexpr auto GridSize(const MoeFlatmmKernelArgs& kargs)
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
|
||||
}
|
||||
@@ -647,8 +659,8 @@ struct MoeFlatmmKernel
|
||||
return make_tuple(a_block_window, b_flat_block_window, c_block_window, scale_block_window);
|
||||
}
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
|
||||
CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs<ScaleM, ScaleN> kargs) const
|
||||
template <class MoeFlatmmKernelArgs>
|
||||
CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs kargs) const
|
||||
{
|
||||
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
|
||||
@@ -845,6 +857,7 @@ struct MoeFlatmmKernel
|
||||
|
||||
constexpr int ActVectorSize = c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle *
|
||||
OutputNumNXdlPerWavePerShuffle;
|
||||
constexpr bool EnableBias = decltype(kargs.exp_bias)::GranularityMN != -1;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
|
||||
@@ -855,6 +868,7 @@ struct MoeFlatmmKernel
|
||||
float vec_scale_B[NRepeat];
|
||||
|
||||
float vec_expert_weights[kM0 * kM2 * MRepeat];
|
||||
float vec_expert_bias[kM0 * kM2 * MRepeat];
|
||||
|
||||
const float* expert_weights = static_cast<const float*>(kargs.p_sorted_expert_weights);
|
||||
|
||||
@@ -883,6 +897,32 @@ struct MoeFlatmmKernel
|
||||
});
|
||||
}
|
||||
}
|
||||
if constexpr(MXFP4_Pipeline && EnableBias)
|
||||
{
|
||||
if constexpr(IsGateUp)
|
||||
{
|
||||
static_for<0, NRepeat / 2, 1>{}([&](auto i) {
|
||||
vec_expert_bias[i * 2] =
|
||||
kargs.exp_bias[expert_id * kargs.N + coord_n / 2 + i * NWave * NPerXdl +
|
||||
iNWarp * NPerXdl + iNLane];
|
||||
vec_expert_bias[i * 2 + 1] =
|
||||
kargs.exp_bias[expert_id * kargs.N + kargs.N / 2 + coord_n / 2 +
|
||||
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NRepeat, 2>{}([&](auto i) {
|
||||
vec_expert_bias[i] =
|
||||
kargs.exp_bias[expert_id * kargs.N + coord_n + i * NWave * NPerXdl +
|
||||
iNWarp * 2 * NPerXdl + iNLane];
|
||||
vec_expert_bias[i + 1] =
|
||||
kargs.exp_bias[expert_id * kargs.N + coord_n + i * NWave * NPerXdl +
|
||||
iNWarp * 2 * NPerXdl + NPerXdl + iNLane];
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto i) {
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
@@ -926,15 +966,16 @@ struct MoeFlatmmKernel
|
||||
c_warp_y_lengths)));
|
||||
});
|
||||
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
{
|
||||
gate_tensor
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
@@ -942,11 +983,19 @@ struct MoeFlatmmKernel
|
||||
up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[2 * n_xdl + 1];
|
||||
});
|
||||
}
|
||||
if constexpr(EnableBias)
|
||||
{
|
||||
gate_tensor
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] +=
|
||||
vec_expert_bias[2 * n_xdl];
|
||||
up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] +=
|
||||
vec_expert_bias[2 * n_xdl + 1];
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
});
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
lds_tile[0].get_thread_buffer().at(idx) =
|
||||
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
|
||||
@@ -967,15 +1016,19 @@ struct MoeFlatmmKernel
|
||||
(m_xdl * NumNXdlPerWavePerShuffle + n_xdl) * c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[0]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_expert_weights[m_xdl * kM0 * kM2 + m0 * kM2 + m2];
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
lds_tile[0]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[n_xdl];
|
||||
if constexpr(EnableBias)
|
||||
lds_tile[0]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] +=
|
||||
vec_expert_bias[n_xdl];
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[0]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_expert_weights[m_xdl * kM0 * kM2 + m0 * kM2 + m2];
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1041,14 +1094,15 @@ struct MoeFlatmmKernel
|
||||
c_warp_y_lengths)));
|
||||
});
|
||||
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
{
|
||||
gate_tensor.get_thread_buffer()[acc_xdl_offset +
|
||||
m0 * kM2 + m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
@@ -1063,10 +1117,24 @@ struct MoeFlatmmKernel
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl + 1];
|
||||
});
|
||||
}
|
||||
if constexpr(EnableBias)
|
||||
{
|
||||
gate_tensor.get_thread_buffer()[acc_xdl_offset +
|
||||
m0 * kM2 + m2] +=
|
||||
vec_expert_bias[nIter_next *
|
||||
NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl];
|
||||
up_tensor.get_thread_buffer()[acc_xdl_offset +
|
||||
m0 * kM2 + m2] +=
|
||||
vec_expert_bias[nIter_next *
|
||||
NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl + 1];
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
lds_tile[write_stage].get_thread_buffer().at(idx) =
|
||||
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
|
||||
@@ -1090,12 +1158,6 @@ struct MoeFlatmmKernel
|
||||
c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
|
||||
m2] *= vec_expert_weights
|
||||
[mIter_next * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2];
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
|
||||
@@ -1105,6 +1167,19 @@ struct MoeFlatmmKernel
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
n_xdl];
|
||||
if constexpr(EnableBias)
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
|
||||
m2] +=
|
||||
vec_expert_bias[nIter_next *
|
||||
NumNXdlPerWavePerShuffle +
|
||||
n_xdl];
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
|
||||
m2] *= vec_expert_weights
|
||||
[mIter_next * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user