Ck moe hot fix (#1979)

* fix useless code and remove useless oob

* clang format

* fix coredump in e2e test

* fix2

* fix clang format

* fix output oob

* clang format

* rm useless comments

---------

Co-authored-by: coderfeli <coderfeli@163.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
This commit is contained in:
felix
2025-03-19 22:58:27 +08:00
committed by GitHub
parent fdaff5603e
commit 7eaedeb36c
2 changed files with 11 additions and 14 deletions

View File

@@ -1563,12 +1563,8 @@ struct GridwiseMoeGemm
const float* p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<index_t, EMRepeats>
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets;
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock]
// & 0xff000000) >> 24;
auto dstidx = sfc_cde_block.GetIndex(access_id);
const index_t c_token_pos =
@@ -1576,7 +1572,9 @@ struct GridwiseMoeGemm
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = p_sorted_weights_0[token_offset * problem.StrideDs[0]];
float weight = token_offset < problem.NumTokens
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
: 0.0;
if constexpr(IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
@@ -2074,9 +2072,6 @@ struct GridwiseMoeGemm
StaticallyIndexedArray<index_t, EMRepeats>
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock]
// & 0xff000000) >> 24;
auto dstidx = sfc_cde_block.GetIndex(access_id);
const index_t c_token_pos =
@@ -2084,7 +2079,9 @@ struct GridwiseMoeGemm
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = p_sorted_weights_0[token_offset * problem.StrideDs[0]];
float weight = token_offset < problem.NumTokens
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
: 0.0;
if constexpr(IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);

View File

@@ -430,13 +430,13 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
}
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
constexpr InMemoryDataOperationEnum DstInMemOp =
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
dst_offset, true, dst_vectors[i].template AsType<dst_vector_t>()[I0]);
dst_offset, is_dst_valid, dst_vectors[i].template AsType<dst_vector_t>()[I0]);
});
// move coordinate