Add validity checks for MoE FlatMM scatter and enable bf16 hardware atomic-add (#3236)

* Add validity checks for MoE FlatMM scatter and enable bf16 hardware atomic-add

* correct clang-format

* removed unused rtol_atol variable from example code

* clang format correction

* remove unused variable max_accumulated_value from example
This commit is contained in:
msaffari-amd
2025-11-28 09:43:01 +01:00
committed by GitHub
parent 30727c48fc
commit f875ab0bbc
2 changed files with 10 additions and 2 deletions

View File

@@ -623,7 +623,7 @@ struct MoeFlatmmKernel
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumToken,
make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens,
IsGateUp ? kargs.N / 2 : kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
@@ -1250,6 +1250,8 @@ struct MoeFlatmmKernel
constexpr int MPerThread = TileEncodingPattern::Y2;
statically_indexed_array<statically_indexed_array<index_t, MPerThread>, NumMEpiTile>
c_scatter_offsets;
statically_indexed_array<statically_indexed_array<bool, MPerThread>, NumMEpiTile>
c_scatter_valids;
auto c_coord = dram_tile_distribution.calculate_index();
static_for<0, NumMEpiTile, 1>{}([&](auto mIter) {
static_for<0, MPerThread, 1>{}([&](auto m0) {
@@ -1262,6 +1264,7 @@ struct MoeFlatmmKernel
scatter_token_id =
scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
});
});
@@ -1302,7 +1305,8 @@ struct MoeFlatmmKernel
c_block_window.get_window_lengths(),
c_block_window.get_window_origin(),
dram_tile_distribution,
c_scatter_offsets[mIter]);
c_scatter_offsets[mIter],
c_scatter_valids[mIter]);
if constexpr(!IsInputGemm ||
EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add)