Fix MoE GEMM1 activation bugs

This commit is contained in:
lalala-sh
2025-04-08 14:30:47 +08:00
committed by GitHub
2 changed files with 26 additions and 7 deletions

View File

@@ -314,6 +314,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
// Initialize C
c_thread_buf.Clear();
c_thread_buf_up.Clear();
__builtin_amdgcn_sched_barrier(0);

View File

@@ -1205,6 +1205,7 @@ struct GridwiseMoeGemm
return {blockIdx.x, blockIdx.y};
}
}();
const index_t block_n_id = block_mn.first;
const index_t block_m_id = block_mn.second;
const index_t token0 =
@@ -1320,7 +1321,7 @@ struct GridwiseMoeGemm
KPerBlock);
if constexpr(IsInputGemm)
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
@@ -1468,8 +1469,18 @@ struct GridwiseMoeGemm
}
else if(ActivationOperation == Activation::gelu)
{
tensor_operation::element_wise::Gelu{}(c_thread_buf(cidx),
c_thread_buf(cidx));
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
auto gate = scale_a * scale_b * c_thread_buf[cidx];
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
else if(ActivationOperation == Activation::swiglu)
{
@@ -1478,7 +1489,12 @@ struct GridwiseMoeGemm
PerTokenQuant];
auto gate = scale_a * scale_b * c_thread_buf[cidx];
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
gate = gate * math::rcp(1.0 + math::exp(-gate));
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
}
@@ -1524,14 +1540,16 @@ struct GridwiseMoeGemm
}
else if(ActivationOperation == Activation::gelu)
{
tensor_operation::element_wise::Gelu{}(c_thread_buf(cidx),
c_thread_buf(cidx));
auto gate = c_thread_buf[cidx];
auto up = c_thread_buf_up[cidx];
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
else if(ActivationOperation == Activation::swiglu)
{
auto gate = c_thread_buf[cidx];
auto up = c_thread_buf_up[cidx];
gate = gate * math::rcp(1.0 + math::exp(-gate));
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
}