change code

This commit is contained in:
coderfeli
2025-03-25 09:44:32 +00:00
parent 0d266bfd65
commit 234b8d415c

View File

@@ -1347,7 +1347,7 @@ struct GridwiseMoeGemm
c_thread_buf_up,
num_k_block_main_loop);
static_assert(NXdlPerWave == 1, "ONLY 1 now");
// static_assert(NXdlPerWave == 1, "ONLY 1 now");
// const float scale_gate = scale_b[0];
// const float scale_up = scale_b[problem.N * perTokenQuantStride];
// static_for<0, c_thread_buf.Size(), 1>{}([&](auto i) {
@@ -1397,13 +1397,17 @@ struct GridwiseMoeGemm
const index_t m1 = get_warp_local_1d_id() / M1;
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
vector_type<int32_t, 4> scale_token_ids;
vector_type<float, 4> topk_weights; // for gemm2 only
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
if constexpr(perTokenQuantStride) {
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
scale_token_ids = *c_style_pointer_cast<vector_type<int32_t, 4> *>(p_sorted_token_ids + m_pos);
scale_token_ids = *c_style_pointer_cast<const vector_type<int32_t, 4> *>(p_sorted_token_ids + m_pos);
}
if constexpr (!IsInputGemm)
{
topk_weights = *c_style_pointer_cast<vector_type<float, 4> *>(p_ds_grid[I2] + m_pos);
}
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
float scale_a = [&]() {
@@ -1421,10 +1425,17 @@ struct GridwiseMoeGemm
constexpr index_t c_offset =
blockwise_gemm_pipeline.c_thread_desc_.CalculateOffset(make_tuple(m0, n0, m2 * M4 + m4));
constexpr auto cidx = Number<c_offset>{};
auto gate = scale_a * scale_gate * c_thread_buf[cidx];
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
gate = gate * math::rcp(1.0 + math::exp(-gate));
c_thread_buf(cidx) = gate * up;
if constexpr (IsInputGemm) // gu fusion
{
auto gate = scale_a * scale_gate * c_thread_buf[cidx];
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
gate = gate * math::rcp(1.0 + math::exp(-gate));
c_thread_buf(cidx) = gate * up;
}
else
{
c_thread_buf(cidx) = scale_a * scale_gate * c_thread_buf[cidx];
}
});
});
});
@@ -1527,17 +1538,8 @@ struct GridwiseMoeGemm
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
const DDataType* ptr_ = p_ds_grid[i];
// hack logic here to support different kind of strides. todo fix it.
// ascale t, 1; bscale E, N, 1, move ptr to E
// if(i.value == 1)
// {
// ptr_ +=
// expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1);
// }
return make_dynamic_buffer<AddressSpaceEnum::Global>(
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
@@ -1667,18 +1669,18 @@ struct GridwiseMoeGemm
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
IndexType token_offset = fused_token & 0xffffff;
float weight = 1.0f;
// float weight = 1.0f;
if constexpr(IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
else
{
const float* p_sorted_weights_2 = p_ds_grid[I2];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
}
// else
// {
// const float* p_sorted_weights_2 = p_ds_grid[I2];
// weight = weight * p_sorted_weights_2[c_token_pos + m0];
// }
scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
scatter_weights(m0) = weight;
// scatter_weights(m0) = weight;
});
block_sync_lds();