Add bias for f16xf4 moe_flatmm

Author: Feng Shijie
Date:   2025-08-28 08:02:50 +00:00
parent  dd6539f366
commit  5c484a5672

5 changed files with 179 additions and 92 deletions


@@ -14,7 +14,9 @@
 namespace ck_tile {

-template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
+template <class ScaleM = FlatmmScalePointer<-1>,
+          class ScaleN = FlatmmScalePointer<-1>,
+          class ExpertBias = FlatmmScalePointer<-1>>
 struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
 {
     ck_tile::index_t NumTokens;
@@ -24,6 +26,7 @@ struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
     const ck_tile::index_t* p_sorted_expert_ids;
     const ck_tile::index_t* p_max_token_id;
     const void* p_sorted_expert_weights;
+    ExpertBias exp_bias;

     CK_TILE_HOST MoeFlatmmHostArgs() noexcept = default;
     CK_TILE_HOST MoeFlatmmHostArgs(const ck_tile::index_t* p_sorted_token_ids_,
@@ -43,8 +46,9 @@ struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
                                    ck_tile::index_t stride_A_,
                                    ck_tile::index_t stride_B_,
                                    ck_tile::index_t stride_C_,
-                                   ScaleM scale_m_ = {},
-                                   ScaleN scale_n_ = {})
+                                   ScaleM scale_m_ = {},
+                                   ScaleN scale_n_ = {},
+                                   ExpertBias exp_bias_ = {})
         : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>(a_ptr_,
                                                  b_ptr_,
                                                  {}, // d_ptr_array
@@ -65,7 +69,8 @@ struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
           p_sorted_token_ids(p_sorted_token_ids_),
           p_sorted_expert_ids(p_sorted_expert_ids_),
           p_max_token_id(p_max_token_id_),
-          p_sorted_expert_weights(p_sorted_expert_weights_)
+          p_sorted_expert_weights(p_sorted_expert_weights_),
+          exp_bias(exp_bias_)
     {
     }
 };
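
The bias travels as one more FlatmmScalePointer-style handle next to the scales. A host-side sketch of the table the kernel expects; the helper name and the [num_experts x N] row-major layout are inferred from the exp_bias indexing used further down (exp_bias[expert_id * kargs.N + column]), not spelled out in this diff:

    #include <cstddef>
    #include <vector>

    // Hypothetical helper: build the per-expert bias table read through exp_bias.
    // The kernel indexes it as exp_bias[expert_id * N + column]; for gate+up
    // kernels, columns [0, N/2) are consumed as gate bias and [N/2, N) as up
    // bias (see the MXFP4_Pipeline gather below).
    std::vector<float> make_expert_bias_table(std::size_t num_experts, std::size_t N)
    {
        std::vector<float> bias(num_experts * N, 0.0f);
        // bias[e * N + c] = bias of output column c for expert e
        return bias;
    }
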
@@ -91,13 +96,14 @@ struct MoeSilu
 struct Swiglu
 {
-    float alpha = 1.702f; // default value used in gpt-oss
-    float limit = 7.0f;   // default value used in gpt-oss
+    const float alpha;
+    const float limit;

     CK_TILE_HOST_DEVICE
-    Swiglu() = default;
-    CK_TILE_HOST_DEVICE
-    Swiglu(float alpha_, float limit_) : alpha(alpha_), limit(limit_) {}
+    Swiglu(float alpha_ = 1.702f, float limit_ = 7.0f) // use value in gpt-oss as default
+        : alpha(alpha_), limit(limit_)
+    {
+    }

     template <typename T>
     CK_TILE_HOST_DEVICE T operator()(T gate, T linear) const
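
Side note on the constructor fold: once alpha and limit are const, `Swiglu() = default;` would be defined as deleted (const members with no initializer), so the two constructors have to collapse into a single one with defaulted arguments. For orientation only, since operator() is untouched by this commit, a standalone sketch of the gpt-oss reference activation these defaults come from (clamp both inputs, sigmoid-gate, shift linear by one); this is a reading of the upstream reference, not code from this diff:

    #include <algorithm>
    #include <cmath>

    // Clamped SwiGLU in the style of the gpt-oss reference (alpha = 1.702, limit = 7.0).
    float swiglu_ref(float gate, float linear, float alpha = 1.702f, float limit = 7.0f)
    {
        gate   = std::min(gate, limit);             // clamp gate from above only
        linear = std::clamp(linear, -limit, limit); // clamp linear on both sides
        // gate * sigmoid(alpha * gate) * (linear + 1)
        return gate / (1.0f + std::exp(-alpha * gate)) * (linear + 1.0f);
    }
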
@@ -190,7 +196,9 @@ struct MoeFlatmmKernel
     static constexpr int WeightPackedSize = numeric_traits<BDataType>::PackedSize;

-    template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
+    template <class ScaleM = FlatmmScalePointer<-1>,
+              class ScaleN = FlatmmScalePointer<-1>,
+              class ExpertBias = FlatmmScalePointer<-1>>
     struct MoeFlatmmKernelArgs
     {
         const ck_tile::index_t* p_sorted_token_ids;
@@ -211,30 +219,34 @@ struct MoeFlatmmKernel
         ck_tile::index_t k_batch;
         ScaleM scale_m;
         ScaleN scale_n;
+        ExpertBias exp_bias;
     };

-    template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
+    template <class ScaleM = FlatmmScalePointer<-1>,
+              class ScaleN = FlatmmScalePointer<-1>,
+              class ExpertBias = FlatmmScalePointer<-1>>
     CK_TILE_HOST static constexpr auto
-    MakeKernelArgs(const MoeFlatmmHostArgs<ScaleM, ScaleN>& hostArgs)
+    MakeKernelArgs(const MoeFlatmmHostArgs<ScaleM, ScaleN, ExpertBias>& hostArgs)
     {
-        return MoeFlatmmKernelArgs<ScaleM, ScaleN>{hostArgs.p_sorted_token_ids,
-                                                   hostArgs.p_sorted_expert_ids,
-                                                   hostArgs.p_max_token_id,
-                                                   hostArgs.p_sorted_expert_weights,
-                                                   hostArgs.a_ptr,
-                                                   hostArgs.b_ptr,
-                                                   hostArgs.e_ptr,
-                                                   hostArgs.NumTokens,
-                                                   hostArgs.TopK,
-                                                   hostArgs.M,
-                                                   hostArgs.N,
-                                                   hostArgs.K,
-                                                   hostArgs.stride_A,
-                                                   hostArgs.stride_B,
-                                                   hostArgs.stride_C,
-                                                   hostArgs.k_batch,
-                                                   hostArgs.scale_m,
-                                                   hostArgs.scale_n};
+        return MoeFlatmmKernelArgs<ScaleM, ScaleN, ExpertBias>{hostArgs.p_sorted_token_ids,
+                                                               hostArgs.p_sorted_expert_ids,
+                                                               hostArgs.p_max_token_id,
+                                                               hostArgs.p_sorted_expert_weights,
+                                                               hostArgs.a_ptr,
+                                                               hostArgs.b_ptr,
+                                                               hostArgs.e_ptr,
+                                                               hostArgs.NumTokens,
+                                                               hostArgs.TopK,
+                                                               hostArgs.M,
+                                                               hostArgs.N,
+                                                               hostArgs.K,
+                                                               hostArgs.stride_A,
+                                                               hostArgs.stride_B,
+                                                               hostArgs.stride_C,
+                                                               hostArgs.k_batch,
+                                                               hostArgs.scale_m,
+                                                               hostArgs.scale_n,
+                                                               hostArgs.exp_bias};
     }

     [[nodiscard]] CK_TILE_HOST static const std::string GetName()
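
Since MoeFlatmmKernelArgs is braced-initialized as an aggregate, exp_bias must sit last both in the member list above and in this initializer list; inserting it anywhere else would silently shift every later field.
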
@@ -249,8 +261,8 @@ struct MoeFlatmmKernel
     {
         return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
     }
-    template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
-    static constexpr auto GridSize(const MoeFlatmmKernelArgs<ScaleM, ScaleN>& kargs)
+    template <class MoeFlatmmKernelArgs>
+    static constexpr auto GridSize(const MoeFlatmmKernelArgs& kargs)
     {
         return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
     }
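
GridSize (and operator() below) is now templated on the whole kernel-args type rather than on ScaleM/ScaleN, so these signatures no longer have to grow with each optional parameter; the template parameter deliberately shadows the MoeFlatmmKernelArgs struct name and deduces any of its instantiations, with or without an ExpertBias.
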
@@ -647,8 +659,8 @@ struct MoeFlatmmKernel
         return make_tuple(a_block_window, b_flat_block_window, c_block_window, scale_block_window);
     }

-    template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
-    CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs<ScaleM, ScaleN> kargs) const
+    template <class MoeFlatmmKernelArgs>
+    CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs kargs) const
     {
         const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
@@ -845,6 +857,7 @@ struct MoeFlatmmKernel
         constexpr int ActVectorSize = c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle *
                                       OutputNumNXdlPerWavePerShuffle;
+        constexpr bool EnableBias = decltype(kargs.exp_bias)::GranularityMN != -1;

         const index_t iMWarp = get_warp_id() / NWave;
         const index_t iNWarp = get_warp_id() - iMWarp * NWave;
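
The switch is purely compile-time: FlatmmScalePointer<-1>, the default everywhere above, reports GranularityMN == -1, so every `if constexpr(EnableBias)` body below is discarded at instantiation and the disabled configuration pays nothing. A self-contained sketch of the mechanism; the stand-in type only mirrors the two properties the kernel relies on (GranularityMN and operator[]), the real FlatmmScalePointer is defined elsewhere in ck_tile:

    // Stand-in with the shape the kernel assumes; illustrative only.
    template <int GranMN>
    struct FlatmmScalePointer
    {
        static constexpr int GranularityMN = GranMN; // -1 acts as the "disabled" sentinel
        const float* ptr = nullptr;
        float operator[](int i) const { return ptr[i]; }
    };

    template <class ExpertBias>
    float bias_or_zero(ExpertBias exp_bias, int col)
    {
        constexpr bool EnableBias = ExpertBias::GranularityMN != -1;
        if constexpr(EnableBias)
            return exp_bias[col]; // only instantiated for enabled bias types
        else
            return 0.0f; // disabled path: no load, no register pressure
    }

    static_assert(FlatmmScalePointer<-1>::GranularityMN == -1, "default = bias disabled");
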
@@ -855,6 +868,7 @@ struct MoeFlatmmKernel
         float vec_scale_B[NRepeat];
         float vec_expert_weights[kM0 * kM2 * MRepeat];
+        float vec_expert_bias[kM0 * kM2 * MRepeat];

         const float* expert_weights = static_cast<const float*>(kargs.p_sorted_expert_weights);
@@ -883,6 +897,32 @@ struct MoeFlatmmKernel
                 });
             }
         }

+        if constexpr(MXFP4_Pipeline && EnableBias)
+        {
+            if constexpr(IsGateUp)
+            {
+                static_for<0, NRepeat / 2, 1>{}([&](auto i) {
+                    vec_expert_bias[i * 2] =
+                        kargs.exp_bias[expert_id * kargs.N + coord_n / 2 + i * NWave * NPerXdl +
+                                       iNWarp * NPerXdl + iNLane];
+                    vec_expert_bias[i * 2 + 1] =
+                        kargs.exp_bias[expert_id * kargs.N + kargs.N / 2 + coord_n / 2 +
+                                       i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
+                });
+            }
+            else
+            {
+                static_for<0, NRepeat, 2>{}([&](auto i) {
+                    vec_expert_bias[i] =
+                        kargs.exp_bias[expert_id * kargs.N + coord_n + i * NWave * NPerXdl +
+                                       iNWarp * 2 * NPerXdl + iNLane];
+                    vec_expert_bias[i + 1] =
+                        kargs.exp_bias[expert_id * kargs.N + coord_n + i * NWave * NPerXdl +
+                                       iNWarp * 2 * NPerXdl + NPerXdl + iNLane];
+                });
+            }
+        }
+
         static_for<0, MRepeat, 1>{}([&](auto i) {
             static_for<0, kM0, 1>{}([&](auto m0) {
                 static_for<0, kM2, 1>{}([&](auto m2) {
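
The two gathers above encode the bias layout. In the gate+up case the expert's bias row is read in halves: even slots of vec_expert_bias take gate bias from columns [0, kargs.N / 2), odd slots take up bias from the same relative offset shifted by kargs.N / 2. The plain case reads two consecutive NPerXdl-wide column groups per step. As a concrete check (assuming NPerXdl = 32 and coord_n = 0 purely for illustration): in the gate+up path at i = 0, lane 5 of N-warp 2 loads gate bias column 2 * 32 + 5 = 69 and up bias column kargs.N / 2 + 69.
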
@@ -926,15 +966,16 @@ struct MoeFlatmmKernel
                                                        c_warp_y_lengths)));
         });

-        if constexpr(!MXFP4_Pipeline)
-            static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
-                static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
-                    constexpr int acc_xdl_offset =
-                        (m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
-                        c_warp_y_lengths.product();
-                    static_for<0, kM0, 1>{}([&](auto m0) {
-                        static_for<0, kM2, 1>{}([&](auto m2) {
+        static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
+            static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
+                constexpr int acc_xdl_offset =
+                    (m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
+                    c_warp_y_lengths.product();
+                static_for<0, kM0, 1>{}([&](auto m0) {
+                    static_for<0, kM2, 1>{}([&](auto m2) {
+                        if constexpr(!MXFP4_Pipeline)
+                        {
                             gate_tensor
                                 .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
                                 vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
@@ -942,11 +983,19 @@ struct MoeFlatmmKernel
                                 vec_scale_B[2 * n_xdl];
                             up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
                                 vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
                                 vec_scale_B[2 * n_xdl + 1];
-                        });
+                        }
+                        if constexpr(EnableBias)
+                        {
+                            gate_tensor
+                                .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] +=
+                                vec_expert_bias[2 * n_xdl];
+                            up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] +=
+                                vec_expert_bias[2 * n_xdl + 1];
+                        }
                     });
                 });
             });
+        });

         static_for<0, ActVectorSize, 1>{}([&](auto idx) {
             lds_tile[0].get_thread_buffer().at(idx) =
                 ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
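
Net effect of the ordering in this epilogue: the bias is added after any accumulator scaling and immediately before ActivationOp consumes each gate/up pair, i.e. each LDS value is ActivationOp(gate_acc (* sA * sB_gate) + b_gate, up_acc (* sA * sB_up) + b_up), where the parenthesized scaling applies only on the non-MXFP4 path.
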
@@ -967,15 +1016,19 @@ struct MoeFlatmmKernel
                     (m_xdl * NumNXdlPerWavePerShuffle + n_xdl) * c_warp_y_lengths.product();
                 static_for<0, kM0, 1>{}([&](auto m0) {
                     static_for<0, kM2, 1>{}([&](auto m2) {
-                        if constexpr(!IsInputGemm)
-                            lds_tile[0]
-                                .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
-                                vec_expert_weights[m_xdl * kM0 * kM2 + m0 * kM2 + m2];
                         if constexpr(!MXFP4_Pipeline)
                             lds_tile[0]
                                 .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
                                 vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
                                 vec_scale_B[n_xdl];
+                        if constexpr(EnableBias)
+                            lds_tile[0]
+                                .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] +=
+                                vec_expert_bias[n_xdl];
+                        if constexpr(!IsInputGemm)
+                            lds_tile[0]
+                                .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
+                                vec_expert_weights[m_xdl * kM0 * kM2 + m0 * kM2 + m2];
                     });
                 });
             });
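
The !IsInputGemm multiply is deliberately moved below the new bias add: the down-projection output becomes (acc (* sA * sB) + b) * w_token instead of the routing weight being applied before the bias, which matches folding the expert's bias into its output before the top-k weighted MoE combine.
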
@@ -1041,14 +1094,15 @@ struct MoeFlatmmKernel
                                                                c_warp_y_lengths)));
                 });

-                if constexpr(!MXFP4_Pipeline)
-                    static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
-                        static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
-                            constexpr int acc_xdl_offset =
-                                (m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
-                                c_warp_y_lengths.product();
-                            static_for<0, kM0, 1>{}([&](auto m0) {
-                                static_for<0, kM2, 1>{}([&](auto m2) {
+                static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
+                    static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
+                        constexpr int acc_xdl_offset =
+                            (m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
+                            c_warp_y_lengths.product();
+                        static_for<0, kM0, 1>{}([&](auto m0) {
+                            static_for<0, kM2, 1>{}([&](auto m2) {
+                                if constexpr(!MXFP4_Pipeline)
+                                {
                                     gate_tensor.get_thread_buffer()[acc_xdl_offset +
                                                                     m0 * kM2 + m2] *=
                                         vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
@@ -1063,10 +1117,24 @@ struct MoeFlatmmKernel
                                                         m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
                                         vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
                                                     2 * n_xdl + 1];
-                                });
+                                }
+                                if constexpr(EnableBias)
+                                {
+                                    gate_tensor.get_thread_buffer()[acc_xdl_offset +
                                                                     m0 * kM2 + m2] +=
+                                        vec_expert_bias[nIter_next *
+                                                            NumNXdlPerWavePerShuffle +
+                                                        2 * n_xdl];
+                                    up_tensor.get_thread_buffer()[acc_xdl_offset +
+                                                                  m0 * kM2 + m2] +=
+                                        vec_expert_bias[nIter_next *
+                                                            NumNXdlPerWavePerShuffle +
+                                                        2 * n_xdl + 1];
+                                }
                             });
                         });
                     });
+                });

                 static_for<0, ActVectorSize, 1>{}([&](auto idx) {
                     lds_tile[write_stage].get_thread_buffer().at(idx) =
                         ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
@@ -1090,12 +1158,6 @@ struct MoeFlatmmKernel
                             c_warp_y_lengths.product();
                         static_for<0, kM0, 1>{}([&](auto m0) {
                             static_for<0, kM2, 1>{}([&](auto m2) {
-                                if constexpr(!IsInputGemm)
-                                    lds_tile[write_stage]
-                                        .get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
-                                                             m2] *= vec_expert_weights
-                                        [mIter_next * NumMXdlPerWavePerShuffle * kM0 * kM2 +
-                                         m_xdl * kM0 * kM2 + m0 * kM2 + m2];
                                 if constexpr(!MXFP4_Pipeline)
                                     lds_tile[write_stage]
                                         .get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
@@ -1105,6 +1167,19 @@ struct MoeFlatmmKernel
                                                         m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
                                         vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
                                                     n_xdl];
+                                if constexpr(EnableBias)
+                                    lds_tile[write_stage]
+                                        .get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
+                                                             m2] +=
+                                        vec_expert_bias[nIter_next *
+                                                            NumNXdlPerWavePerShuffle +
+                                                        n_xdl];
+                                if constexpr(!IsInputGemm)
+                                    lds_tile[write_stage]
+                                        .get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
+                                                             m2] *= vec_expert_weights
+                                        [mIter_next * NumMXdlPerWavePerShuffle * kM0 * kM2 +
+                                         m_xdl * kM0 * kM2 + m0 * kM2 + m2];
                             });
                         });
                     });