Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-06-29 11:16:59 +00:00)
change code
This commit is contained in:
@@ -1347,7 +1347,7 @@ struct GridwiseMoeGemm
|
||||
c_thread_buf_up,
|
||||
num_k_block_main_loop);
|
||||
|
||||
static_assert(NXdlPerWave == 1, "ONLY 1 now");
|
||||
// static_assert(NXdlPerWave == 1, "ONLY 1 now");
|
||||
// const float scale_gate = scale_b[0];
|
||||
// const float scale_up = scale_b[problem.N * perTokenQuantStride];
|
||||
// static_for<0, c_thread_buf.Size(), 1>{}([&](auto i) {
|
||||
@@ -1397,13 +1397,17 @@ struct GridwiseMoeGemm
|
||||
const index_t m1 = get_warp_local_1d_id() / M1;
|
||||
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
|
||||
vector_type<int32_t, 4> scale_token_ids;
|
||||
vector_type<float, 4> topk_weights; // for gemm2 only
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
|
||||
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
|
||||
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
|
||||
|
||||
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
|
||||
if constexpr(perTokenQuantStride) {
|
||||
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
|
||||
scale_token_ids = *c_style_pointer_cast<vector_type<int32_t, 4> *>(p_sorted_token_ids + m_pos);
|
||||
scale_token_ids = *c_style_pointer_cast<const vector_type<int32_t, 4> *>(p_sorted_token_ids + m_pos);
|
||||
}
|
||||
if constexpr (!IsInputGemm)
|
||||
{
|
||||
topk_weights = *c_style_pointer_cast<vector_type<float, 4> *>(p_ds_grid[I2] + m_pos);
|
||||
}
|
||||
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
|
||||
float scale_a = [&]() {
|
||||
@@ -1421,10 +1425,17 @@ struct GridwiseMoeGemm
|
||||
constexpr index_t c_offset =
|
||||
blockwise_gemm_pipeline.c_thread_desc_.CalculateOffset(make_tuple(m0, n0, m2 * M4 + m4));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
auto gate = scale_a * scale_gate * c_thread_buf[cidx];
|
||||
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
gate = gate * math::rcp(1.0 + math::exp(-gate));
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
if constexpr (IsInputGemm) // gu fusion
|
||||
{
|
||||
auto gate = scale_a * scale_gate * c_thread_buf[cidx];
|
||||
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
gate = gate * math::rcp(1.0 + math::exp(-gate));
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else
|
||||
{
|
||||
c_thread_buf(cidx) = scale_a * scale_gate * c_thread_buf[cidx];
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1527,17 +1538,8 @@ struct GridwiseMoeGemm
|
||||
|
||||
const auto ds_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
const DDataType* ptr_ = p_ds_grid[i];
|
||||
// hack logic here to support different kind of strides. todo fix it.
|
||||
// ascale t, 1; bscale E, N, 1, move ptr to E
|
||||
// if(i.value == 1)
|
||||
// {
|
||||
// ptr_ +=
|
||||
// expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1);
|
||||
// }
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
|
||||
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
@@ -1667,18 +1669,18 @@ struct GridwiseMoeGemm
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
IndexType token_offset = fused_token & 0xffffff;
|
||||
float weight = 1.0f;
|
||||
// float weight = 1.0f;
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
else
|
||||
{
|
||||
const float* p_sorted_weights_2 = p_ds_grid[I2];
|
||||
weight = weight * p_sorted_weights_2[c_token_pos + m0];
|
||||
}
|
||||
// else
|
||||
// {
|
||||
// const float* p_sorted_weights_2 = p_ds_grid[I2];
|
||||
// weight = weight * p_sorted_weights_2[c_token_pos + m0];
|
||||
// }
|
||||
scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
|
||||
scatter_weights(m0) = weight;
|
||||
// scatter_weights(m0) = weight;
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
Reference in New Issue
Block a user