optimize w8only gemm

This commit is contained in:
YilinZhao
2025-04-22 12:22:43 +08:00
parent 441ae92821
commit c6723dc34e
8 changed files with 366 additions and 185 deletions

View File

@@ -29,16 +29,16 @@ static constexpr ck::index_t BlockSize = 256;
// Compile-time tuning constants; the ones from MPerBlock onward are passed by
// name in the DeviceGemmV2Instance template argument list.
// NOTE(review): this listing looks like a diff hunk whose +/- markers were
// stripped: MPerBlock/NPerBlock/KPerBlock and MPerXDL/NPerXDL/MXdlPerWave each
// appear twice. Presumably the first occurrence of each is the removed (old)
// value and the second is the added (new) value — confirm against the commit.
static constexpr ck::index_t Scale_Block_N = 1;
static constexpr ck::index_t Scale_Block_K = 128;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t KPerBlock = 64;
static constexpr ck::index_t MPerBlock = 16;
static constexpr ck::index_t NPerBlock = 64;
static constexpr ck::index_t KPerBlock = 256;
static constexpr ck::index_t AK1 = 8;
static constexpr ck::index_t BK1 = 32;
static constexpr ck::index_t MPerXDL = 32;
static constexpr ck::index_t NPerXDL = 32;
static constexpr ck::index_t MXdlPerWave = 4;
static constexpr ck::index_t MPerXDL = 16;
static constexpr ck::index_t NPerXDL = 16;
static constexpr ck::index_t MXdlPerWave = 1;
static constexpr ck::index_t NXdlPerWave = 1;
// clang-format off
@@ -52,13 +52,29 @@ using DeviceGemmV2Instance =
MPerBlock, NPerBlock, KPerBlock,
AK1, BK1,
MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, 0,
1, 1, S<1, 32, 1, 8>, 8,
1, 1, S<1, 16, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>;
// Second instantiation of DeviceGemm_Xdl_CShuffleV3. It shares the layout, data
// type and element-op arguments with the rest of this file but spells its tuning
// values as literals instead of the named constants, and it selects
// BlockGemmPipelineVersion::v4 (DeviceGemmV2Instance's argument list ends in v3).
// NOTE(review): the per-line labels below are inferred purely by position from
// DeviceGemmV2Instance's argument list and the order of the named constants —
// confirm against the DeviceGemm_Xdl_CShuffleV3 template declaration.
using DeviceGemmV2Instance2 =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
64, // presumably the BlockSize slot (the named constant is 256)
1, 128, // presumably the Scale_Block_N, Scale_Block_K slots
16, 16, 128, // presumably the MPerBlock, NPerBlock, KPerBlock slots
8, 16, // presumably the AK1, BK1 slots
16, 16, 1, 1, // presumably the MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave slots
// Two groups of the same shape (three S<> sequences, then four integers); in
// DeviceGemmV2Instance the corresponding groups sit at the same position.
// TODO confirm which group configures the A transfer and which the B transfer.
S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
1, 1, S<1, 16, 1, 4>, 4,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4, CDataType, CDataType, PermuteA, PermuteB>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
@@ -276,9 +292,9 @@ bool run_gemm_splitk_example(int argc, char* argv[])
ProblemSizeSplitK problem_size;
ExecutionConfig config;
problem_size.M = 2048;
problem_size.N = 1024;
problem_size.K = 1024;
problem_size.M = 8;
problem_size.N = 4096; //1024
problem_size.K = 1024; //4096
config.do_verification = true;
config.init_method = 1;