updated the codes

This commit is contained in:
mtgu0705
2025-06-04 02:04:28 -05:00
parent 5117e99822
commit 40ed20a30d
5 changed files with 176 additions and 327 deletions

View File

@@ -143,7 +143,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t MPerBlock = 128;
static constexpr bool MulRoutedWeight = true;
// clang-format off
@@ -151,15 +151,15 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 64,
MPerBlock, 32, KPerBlock,
ScaleBlockSize, 256,
MPerBlock, 128, KPerBlock,
16, 16,
16, 16,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
@@ -170,14 +170,14 @@ int main(int argc, char* argv[])
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 2;
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 2;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
@@ -418,7 +418,7 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
#if 1
#if 0
printf("a0_t_k_k:\n");
// for(int t = 0; t < tokens; ++t)
// {
@@ -671,7 +671,7 @@ int main(int argc, char* argv[])
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
#if 1
#if 0
printf("e_t_n_device_result:\n");
for(int t = 0; t < tokens; ++t)
{