set 16x16

This commit is contained in:
coderfeli
2025-04-25 03:09:53 +00:00
parent 2054e165bc
commit f9c29b5ec7
2 changed files with 23 additions and 26 deletions

View File

@@ -155,13 +155,13 @@ using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MXDLPerWave = 2;
static constexpr ck::index_t NXDLPerWave = 2;
static constexpr ck::index_t MXDLPerWave = 4;
static constexpr ck::index_t NXDLPerWave = 4;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t Nswizzle = true;
static constexpr ck::index_t Nswizzle = false;
static constexpr bool MulRoutedWeight = false;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
@@ -188,7 +188,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
2, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, D2Vec>,
2, 2, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>;
// clang-format on
@@ -201,11 +201,11 @@ int main(int argc, char* argv[])
// GEMM shape
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t K = 6144;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 8;
ck::index_t valid_tile_num = 8;
ck::index_t tokens = 128;
ck::index_t sorted_tile_num = 133;
ck::index_t valid_tile_num = 128;
ck::index_t tokens = 8192;
ck::index_t topk = 2;
// ck::index_t tokens = batch * topk;
@@ -268,11 +268,10 @@ int main(int argc, char* argv[])
// int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
max_token_id.mData = {valid_size};
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
expert_ids.mData[i] = i / (valid_tile_num / experts);
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;