mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
fix moe gemm1 act bugs
This commit is contained in:
@@ -314,6 +314,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
c_thread_buf_up.Clear();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
|
||||
@@ -1205,6 +1205,7 @@ struct GridwiseMoeGemm
|
||||
return {blockIdx.x, blockIdx.y};
|
||||
}
|
||||
}();
|
||||
|
||||
const index_t block_n_id = block_mn.first;
|
||||
const index_t block_m_id = block_mn.second;
|
||||
const index_t token0 =
|
||||
@@ -1320,7 +1321,7 @@ struct GridwiseMoeGemm
|
||||
KPerBlock);
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid_up + expert_id * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
@@ -1468,8 +1469,18 @@ struct GridwiseMoeGemm
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu)
|
||||
{
|
||||
tensor_operation::element_wise::Gelu{}(c_thread_buf(cidx),
|
||||
c_thread_buf(cidx));
|
||||
const float scale_up =
|
||||
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
|
||||
PerTokenQuant];
|
||||
auto gate = scale_a * scale_b * c_thread_buf[cidx];
|
||||
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::swiglu)
|
||||
{
|
||||
@@ -1478,7 +1489,12 @@ struct GridwiseMoeGemm
|
||||
PerTokenQuant];
|
||||
auto gate = scale_a * scale_b * c_thread_buf[cidx];
|
||||
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
gate = gate * math::rcp(1.0 + math::exp(-gate));
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
@@ -1524,14 +1540,16 @@ struct GridwiseMoeGemm
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu)
|
||||
{
|
||||
tensor_operation::element_wise::Gelu{}(c_thread_buf(cidx),
|
||||
c_thread_buf(cidx));
|
||||
auto gate = c_thread_buf[cidx];
|
||||
auto up = c_thread_buf_up[cidx];
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::swiglu)
|
||||
{
|
||||
auto gate = c_thread_buf[cidx];
|
||||
auto up = c_thread_buf_up[cidx];
|
||||
gate = gate * math::rcp(1.0 + math::exp(-gate));
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user