[rocm-libraries] ROCm/rocm-libraries#4297 (commit 5ff580c)

moe flatmm xcd remap
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

co-authors: @Chi-Chu319 @juuso-oskari

Added XCD remapping for flatmm moe
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File

href="file:///C:/Users/tianxiwu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List

href="file:///C:/Users/tianxiwu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<style>
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Arial, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
-->
</style>
</head>

<body link="#467886" vlink="#96607D">

batch | Mixtral (tflops, wip_355) | Mixtral-7B  (tflops, our branch) |
perf boost
-- | -- | -- | --
64 | 865.424 | 995.455 | 15.0%
256 | 886.336 | 1020.96 | 15.2%
1024 | 890.808 | 1022.53 | 14.8%

</body>

</html>
This commit is contained in:
Tianxing Wu
2026-02-18 19:33:24 +00:00
committed by assistant-librarian[bot]
parent 5cb8109535
commit 0a2b6c4bcd
2 changed files with 61 additions and 7 deletions

View File

@@ -901,16 +901,25 @@ struct MoeFlatmmKernel
template <class MoeFlatmmKernelArgs>
CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs kargs) const
{
int partition_idx = blockIdx.x;
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
// total number of tokens: sorted tokens + delimiter tokens + trailing padding tokens
// we launch the grid based on the total number of tokens which needs to be static
int partition_idx = blockIdx.x;
auto max_token_id = kargs.p_max_token_id[0]; // sorted tokens + delimiter tokens
int total_valid_tile_cnt = TilePartitioner::GridSize(max_token_id, kargs.N);
auto tilePartitioner = TilePartitioner{max_token_id, kargs.N};
do
{
if(partition_idx >= total_valid_tile_cnt)
{
return; // early exit for trailing padding tokens
}
partition_idx = tilePartitioner.RemapXCD(partition_idx, total_valid_tile_cnt);
const auto [block_offset_m, block_offset_n] =
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
tilePartitioner.GetOutputTileIndex(partition_idx);
this->operator()(kargs, block_offset_m, block_offset_n);
partition_idx += gridDim.x;
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
} while(UsePersistentKernel && partition_idx < total_valid_tile_cnt);
}
template <class MoeFlatmmKernelArgs>
@@ -920,7 +929,6 @@ struct MoeFlatmmKernel
// const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
const index_t coord_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t coord_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const index_t max_token_id = kargs.p_max_token_id[0];
// allocate LDS
__shared__ char smem_ptr_ping[GetSmemPingSize()];
__shared__ char smem_ptr_pong[GetSmemPongSize()];
@@ -948,8 +956,6 @@ struct MoeFlatmmKernel
return gather_token_id;
};
if(coord_m >= max_token_id)
return;
static_for<0, DramMRepeat, 1>{}([&](auto m0) {
const auto row_idx =
coord_m + m0 * (TilePartitioner::MPerBlock / DramMRepeat) + a_coord[I0];

View File

@@ -265,6 +265,54 @@ struct GemmSpatiallyLocalTilePartitioner
return integer_divide_ceil(K, KPerBlock);
}
/**
* @brief XCDs access ids in round robin format, this function remaps the 1D ids to continguous
* XCD segments
*
* @param block_1d_id grid 1D id
* @param total_num_tiles size of the 1D grid
* @param NUM_XCDS number of XCDs
* @return index_t The id after XCD remap
*/
CK_TILE_HOST_DEVICE static auto
RemapXCD(index_t block_1d_id, index_t total_num_tiles, index_t NUM_XCDS = 8) noexcept -> index_t
{
// Number of ids per XCD in the new arrangement
index_t ids_per_xcd = (total_num_tiles + NUM_XCDS - 1) / NUM_XCDS;
// When total_num_tiles cannot divide NUM_XCDS, some xcds will have
// ids_per_xcd ids, the other will have ids_per_xcd - 1 ids.
// We calculate the number of xcds that have ids_per_xcd ids as tall_xcds
index_t tall_xcds = total_num_tiles % NUM_XCDS;
tall_xcds = (tall_xcds == 0) ? NUM_XCDS : tall_xcds;
// Compute current XCD and local id within the XCD
index_t xcd = block_1d_id % NUM_XCDS;
index_t local_id = block_1d_id / NUM_XCDS;
// Calculate new id based on the new grouping
if(xcd < tall_xcds)
{
block_1d_id = xcd * ids_per_xcd + local_id;
}
else
{
block_1d_id =
tall_xcds * ids_per_xcd + (xcd - tall_xcds) * (ids_per_xcd - 1) + local_id;
}
/**
* original ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
* XCD 0 gets: [0, 8], XCD 1 gets: [1, 9], ...
*
* post-remap ids: [0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15]
* XCD 0 gets: [0, 1], XCD 1 gets: [2, 3], ...
*
* after remap the ids are continguous on each XCD
*/
return block_1d_id;
}
/**
* @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
*