From 77f9a0a615c6015789b70a8432b48260fc3282ad Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Tue, 9 Dec 2025 17:54:55 +0800 Subject: [PATCH] fix a16w4 moe bugs (#3373) * fix valid mask bug * update format [ROCm/composable_kernel commit: 6f0966e1e9fca5c513d16a729237d676b583e266] --- include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index b3b34a6da0..7104547247 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -1259,12 +1259,12 @@ struct MoeFlatmmKernel auto fused_token = kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0] - index_t scatter_token_id = fused_token & token_id_mask; + index_t scatter_token_id = fused_token & token_id_mask; + c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); if constexpr(IsInputGemm) scatter_token_id = scatter_token_id * kargs.TopK + (fused_token >> token_id_offset); c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C; - c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); }); });