From 4f5a48c910fa1b24fc623b00470d9e5ad678a3bd Mon Sep 17 00:00:00 2001 From: msaffari-amd Date: Fri, 28 Nov 2025 09:43:01 +0100 Subject: [PATCH] Add validity checks for MoE FlatMM scatter and enable bf16 hardware atomic-add (#3236) * Add validity checks for MoE FlatMM scatter and enable bf16 hardware atomic * correct clang-format * removed unused rtol_atol variable from example code * clang format correction * remove unused variable max_accumulated_value from example [ROCm/composable_kernel commit: f875ab0bbc6ea68a689a688a58f9a53ad12fd536] --- include/ck_tile/core/arch/generic_memory_space_atomic.hpp | 4 ++++ include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp | 8 ++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp index 3d2112f9ca..2fcd76c5e7 100644 --- a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp +++ b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp @@ -102,6 +102,9 @@ CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x); template <> CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x) { +#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN + __builtin_amdgcn_global_atomic_fadd_v2bf16(c_style_pointer_cast(p_dst), x); +#else union U32BF162_ADDR { uint32_t* u32_a; @@ -128,6 +131,7 @@ CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x) new_v = new_.u32; cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v); } while(cur_v.u32 != old_v); +#endif } template <> diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 8a4d035e13..b3b34a6da0 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -623,7 +623,7 @@ struct MoeFlatmmKernel { return make_naive_tensor_view( e_ptr, - make_tuple(IsInputGemm ? 
kargs.NumTokens * kargs.TopK : kargs.NumToken, + make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens, IsGateUp ? kargs.N / 2 : kargs.N), make_tuple(1, kargs.stride_C), number<1>{}, @@ -1250,6 +1250,8 @@ struct MoeFlatmmKernel constexpr int MPerThread = TileEncodingPattern::Y2; statically_indexed_array, NumMEpiTile> c_scatter_offsets; + statically_indexed_array, NumMEpiTile> + c_scatter_valids; auto c_coord = dram_tile_distribution.calculate_index(); static_for<0, NumMEpiTile, 1>{}([&](auto mIter) { static_for<0, MPerThread, 1>{}([&](auto m0) { @@ -1262,6 +1264,7 @@ struct MoeFlatmmKernel scatter_token_id = scatter_token_id * kargs.TopK + (fused_token >> token_id_offset); c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C; + c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); }); }); @@ -1302,7 +1305,8 @@ struct MoeFlatmmKernel c_block_window.get_window_lengths(), c_block_window.get_window_origin(), dram_tile_distribution, - c_scatter_offsets[mIter]); + c_scatter_offsets[mIter], + c_scatter_valids[mIter]); if constexpr(!IsInputGemm || EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add)