mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4297 (commit 5ff580c)
moe flatmm xcd remap MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit co-authors: @Chi-Chu319 @juuso-oskari Added XCD remapping for flatmm moe <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=Excel.Sheet> <meta name=Generator content="Microsoft Excel 15"> <link id=Main-File rel=Main-File href="file:///C:/Users/tianxiwu/AppData/Local/Temp/msohtmlclip1/01/clip.htm"> <link rel=File-List href="file:///C:/Users/tianxiwu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml"> <style> <!--table {mso-displayed-decimal-separator:"\."; mso-displayed-thousand-separator:"\,";} @page {margin:.75in .7in .75in .7in; mso-header-margin:.3in; mso-footer-margin:.3in;} tr {mso-height-source:auto;} col {mso-width-source:auto;} br {mso-data-placement:same-cell;} td {padding-top:1px; padding-right:1px; padding-left:1px; mso-ignore:padding; color:black; font-size:11.0pt; font-weight:400; font-style:normal; text-decoration:none; font-family:Arial, sans-serif; mso-font-charset:0; mso-number-format:General; text-align:general; vertical-align:bottom; border:none; mso-background-source:auto; mso-pattern:auto; mso-protection:locked visible; white-space:nowrap; mso-rotate:0;} --> </style> </head> <body link="#467886" vlink="#96607D"> batch | Mixtral (tflops, wip_355) | Mixtral-7B (tflops, our branch) | perf boost -- | -- | -- | -- 64 | 865.424 | 995.455 | 15.0% 256 | 886.336 | 1020.96 | 15.2% 1024 | 890.808 | 1022.53 | 14.8% </body> </html>
This commit is contained in:
committed by
assistant-librarian[bot]
parent
5cb8109535
commit
0a2b6c4bcd
@@ -901,16 +901,25 @@ struct MoeFlatmmKernel
|
||||
template <class MoeFlatmmKernelArgs>
|
||||
CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs kargs) const
|
||||
{
|
||||
int partition_idx = blockIdx.x;
|
||||
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
// total number of tokens: sorted tokens + delimiter tokens + trailing padding tokens
|
||||
// we launch the grid based on the total number of tokens which needs to be static
|
||||
int partition_idx = blockIdx.x;
|
||||
auto max_token_id = kargs.p_max_token_id[0]; // sorted tokens + delimiter tokens
|
||||
int total_valid_tile_cnt = TilePartitioner::GridSize(max_token_id, kargs.N);
|
||||
auto tilePartitioner = TilePartitioner{max_token_id, kargs.N};
|
||||
do
|
||||
{
|
||||
if(partition_idx >= total_valid_tile_cnt)
|
||||
{
|
||||
return; // early exit for trailing padding tokens
|
||||
}
|
||||
partition_idx = tilePartitioner.RemapXCD(partition_idx, total_valid_tile_cnt);
|
||||
const auto [block_offset_m, block_offset_n] =
|
||||
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
|
||||
tilePartitioner.GetOutputTileIndex(partition_idx);
|
||||
|
||||
this->operator()(kargs, block_offset_m, block_offset_n);
|
||||
partition_idx += gridDim.x;
|
||||
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
|
||||
} while(UsePersistentKernel && partition_idx < total_valid_tile_cnt);
|
||||
}
|
||||
|
||||
template <class MoeFlatmmKernelArgs>
|
||||
@@ -920,7 +929,6 @@ struct MoeFlatmmKernel
|
||||
// const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
|
||||
const index_t coord_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t coord_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
const index_t max_token_id = kargs.p_max_token_id[0];
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_ping[GetSmemPingSize()];
|
||||
__shared__ char smem_ptr_pong[GetSmemPongSize()];
|
||||
@@ -948,8 +956,6 @@ struct MoeFlatmmKernel
|
||||
return gather_token_id;
|
||||
};
|
||||
|
||||
if(coord_m >= max_token_id)
|
||||
return;
|
||||
static_for<0, DramMRepeat, 1>{}([&](auto m0) {
|
||||
const auto row_idx =
|
||||
coord_m + m0 * (TilePartitioner::MPerBlock / DramMRepeat) + a_coord[I0];
|
||||
|
||||
Reference in New Issue
Block a user