diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index a9c21681d0..d1cbb87b62 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -1250,6 +1250,8 @@ struct MoeFlatmmKernel constexpr int MPerThread = TileEncodingPattern::Y2; statically_indexed_array, NumMEpiTile> c_scatter_offsets; + statically_indexed_array, NumMEpiTile> + c_scatter_valids; auto c_coord = dram_tile_distribution.calculate_index(); static_for<0, NumMEpiTile, 1>{}([&](auto mIter) { static_for<0, MPerThread, 1>{}([&](auto m0) { @@ -1262,6 +1264,7 @@ struct MoeFlatmmKernel scatter_token_id = scatter_token_id * kargs.TopK + (fused_token >> token_id_offset); c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C; + c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); }); }); @@ -1302,7 +1305,8 @@ struct MoeFlatmmKernel c_block_window.get_window_lengths(), c_block_window.get_window_origin(), dram_tile_distribution, - c_scatter_offsets[mIter]); + c_scatter_offsets[mIter], + c_scatter_valids[mIter]); if constexpr(!IsInputGemm || EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add)