mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
test expert = 8 and function pass.
This commit is contained in:
@@ -152,7 +152,8 @@ using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
#if 0
|
||||
#if 1
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 32;
|
||||
static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
@@ -168,7 +169,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
64, MPerBlock, 16, KPerBlock,
|
||||
256, MPerBlock, 128, KPerBlock,
|
||||
AK1, BK1,
|
||||
MNPerXDL, MNPerXDL,
|
||||
MXDLPerWave, 1,
|
||||
@@ -208,12 +209,12 @@ int main(int argc, char* argv[])
|
||||
// GEMM shape
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 8192;
|
||||
ck::index_t experts = 1;
|
||||
ck::index_t sorted_tile_num = 1;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 8;
|
||||
ck::index_t sorted_tile_size = MPerBlock;
|
||||
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
|
||||
// ck::index_t tokens = 128;
|
||||
ck::index_t tokens = 16;
|
||||
ck::index_t tokens = 128;
|
||||
// ck::index_t tokens = 16;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user