mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Update moe_flatmm_kernel to manage OOB
This commit is contained in:
@@ -1250,6 +1250,8 @@ struct MoeFlatmmKernel
|
||||
constexpr int MPerThread = TileEncodingPattern::Y2;
|
||||
statically_indexed_array<statically_indexed_array<index_t, MPerThread>, NumMEpiTile>
|
||||
c_scatter_offsets;
|
||||
statically_indexed_array<statically_indexed_array<bool, MPerThread>, NumMEpiTile>
|
||||
c_scatter_valids;
|
||||
auto c_coord = dram_tile_distribution.calculate_index();
|
||||
static_for<0, NumMEpiTile, 1>{}([&](auto mIter) {
|
||||
static_for<0, MPerThread, 1>{}([&](auto m0) {
|
||||
@@ -1262,6 +1264,7 @@ struct MoeFlatmmKernel
|
||||
scatter_token_id =
|
||||
scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
|
||||
c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
|
||||
c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1302,7 +1305,8 @@ struct MoeFlatmmKernel
|
||||
c_block_window.get_window_lengths(),
|
||||
c_block_window.get_window_origin(),
|
||||
dram_tile_distribution,
|
||||
c_scatter_offsets[mIter]);
|
||||
c_scatter_offsets[mIter],
|
||||
c_scatter_valids[mIter]);
|
||||
|
||||
if constexpr(!IsInputGemm ||
|
||||
EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add)
|
||||
|
||||
Reference in New Issue
Block a user