Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-07-03 05:37:34 +00:00)
optimize w8only gemm
This commit is contained in:
@@ -29,16 +29,16 @@ static constexpr ck::index_t BlockSize = 256;
|
||||
static constexpr ck::index_t Scale_Block_N = 1;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t KPerBlock = 64;
|
||||
static constexpr ck::index_t MPerBlock = 16;
|
||||
static constexpr ck::index_t NPerBlock = 64;
|
||||
static constexpr ck::index_t KPerBlock = 256;
|
||||
|
||||
static constexpr ck::index_t AK1 = 8;
|
||||
static constexpr ck::index_t BK1 = 32;
|
||||
|
||||
static constexpr ck::index_t MPerXDL = 32;
|
||||
static constexpr ck::index_t NPerXDL = 32;
|
||||
static constexpr ck::index_t MXdlPerWave = 4;
|
||||
static constexpr ck::index_t MPerXDL = 16;
|
||||
static constexpr ck::index_t NPerXDL = 16;
|
||||
static constexpr ck::index_t MXdlPerWave = 1;
|
||||
static constexpr ck::index_t NXdlPerWave = 1;
|
||||
|
||||
// clang-format off
|
||||
@@ -52,13 +52,29 @@ using DeviceGemmV2Instance =
|
||||
MPerBlock, NPerBlock, KPerBlock,
|
||||
AK1, BK1,
|
||||
MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 0,
|
||||
1, 1, S<1, 32, 1, 8>, 8,
|
||||
1, 1, S<1, 16, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>;
|
||||
|
||||
using DeviceGemmV2Instance2 =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
64,
|
||||
1, 128,
|
||||
16, 16, 128,
|
||||
8, 16,
|
||||
16, 16, 1, 1,
|
||||
S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
1, 1, S<1, 16, 1, 4>, 4,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4, CDataType, CDataType, PermuteA, PermuteB>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
@@ -276,9 +292,9 @@ bool run_gemm_splitk_example(int argc, char* argv[])
|
||||
ProblemSizeSplitK problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
problem_size.M = 2048;
|
||||
problem_size.N = 1024;
|
||||
problem_size.K = 1024;
|
||||
problem_size.M = 8;
|
||||
problem_size.N = 4096; //1024
|
||||
problem_size.K = 1024; //4096
|
||||
|
||||
config.do_verification = true;
|
||||
config.init_method = 1;
|
||||
|
||||
Reference in New Issue
Block a user