merge flatmm pipe v0 from dteng_flatmm_opt

This commit is contained in:
valarLip
2025-07-23 08:44:12 +00:00
parent 6dacf833da
commit 89fa639207
5 changed files with 987 additions and 105 deletions

View File

@@ -1,15 +1,13 @@
add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp)
set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-local-typedef -Wno-unused-variable -Wno-unused-parameter)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32_F8=1 -Wno-unused-local-typedef)
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16_F8=1 -Wno-unused-local-typedef)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x128_F8=1 -Wno-unused-local-typedef)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0")
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps -Wno-nrvo)
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0")
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --disable-schedmodel-in-sched-mi=1 -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental -mllvm --misched-bottomup=1")
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})

View File

@@ -63,7 +63,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV0<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
@@ -90,7 +90,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
tail_number_v>;
using CodegenFlatmmPipeline =
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
ck_tile::FlatmmPipelineAGmemBGmemCRegV0<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,

View File

@@ -86,7 +86,7 @@ struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
static constexpr int kBlockPerCu = 1;
static constexpr int kBlockPerCu = 2;
};
template <typename ADataType>
@@ -167,119 +167,119 @@ struct is_8bit_type
{
};
template <typename DataType>
struct GemmConfig
{
#if defined(USING_MFMA_16x16x128_F8) //MI350 FP8 16X16
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
// template <typename DataType>
// struct GemmConfig
// {
// #if defined(USING_MFMA_16x16x128_F8) //MI350 FP8 16X16
// static constexpr ck_tile::index_t M_Tile = 128;
// static constexpr ck_tile::index_t N_Tile = 256;
// static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 4;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
#elif defined(USING_MFMA_32x32x64_F8) //MI350 FP8 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
// static constexpr ck_tile::index_t M_Warp_Tile = 16;
// static constexpr ck_tile::index_t N_Warp_Tile = 16;
// static constexpr ck_tile::index_t K_Warp_Tile = 128;
// #elif defined(USING_MFMA_32x32x64_F8) //MI350 FP8 32X32 (need tune)
// static constexpr ck_tile::index_t M_Tile = 128;
// static constexpr ck_tile::index_t N_Tile = 128;
// static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 4;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
#elif defined(USING_MFMA_16x16x32_F16) //MI350 FP16 16X16 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
// static constexpr ck_tile::index_t M_Warp_Tile = 32;
// static constexpr ck_tile::index_t N_Warp_Tile = 32;
// static constexpr ck_tile::index_t K_Warp_Tile = 64;
// #elif defined(USING_MFMA_16x16x32_F16) //MI350 FP16 16X16 (need tune)
// static constexpr ck_tile::index_t M_Tile = 128;
// static constexpr ck_tile::index_t N_Tile = 128;
// static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 4;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
#elif defined(USING_MFMA_32x32x16_F16) //MI350 FP16 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
// static constexpr ck_tile::index_t M_Warp_Tile = 16;
// static constexpr ck_tile::index_t N_Warp_Tile = 16;
// static constexpr ck_tile::index_t K_Warp_Tile = 32;
// #elif defined(USING_MFMA_32x32x16_F16) //MI350 FP16 32X32 (need tune)
// static constexpr ck_tile::index_t M_Tile = 128;
// static constexpr ck_tile::index_t N_Tile = 128;
// static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 4;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
#elif defined(USING_MFMA_16x16x32_F8) //MI300 FP8 16X16
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
// static constexpr ck_tile::index_t M_Warp_Tile = 32;
// static constexpr ck_tile::index_t N_Warp_Tile = 32;
// static constexpr ck_tile::index_t K_Warp_Tile = 16;
// #elif defined(USING_MFMA_16x16x32_F8) //MI300 FP8 16X16
// static constexpr ck_tile::index_t M_Tile = 16;
// static constexpr ck_tile::index_t N_Tile = 64;
// static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 4;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
#elif defined(USING_MFMA_32x32x16_F8) //MI300 FP8 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128;
// static constexpr ck_tile::index_t M_Warp_Tile = 16;
// static constexpr ck_tile::index_t N_Warp_Tile = 16;
// static constexpr ck_tile::index_t K_Warp_Tile = 64;
// #elif defined(USING_MFMA_32x32x16_F8) //MI300 FP8 32X32 (need tune)
// static constexpr ck_tile::index_t M_Tile = 128;
// static constexpr ck_tile::index_t N_Tile = 256;
// static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 8;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 8;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
#elif defined(USING_MFMA_16x16x16_F16) //MI300 FP16 16X16 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
// static constexpr ck_tile::index_t M_Warp_Tile = 32;
// static constexpr ck_tile::index_t N_Warp_Tile = 32;
// static constexpr ck_tile::index_t K_Warp_Tile = 32;
// #elif defined(USING_MFMA_16x16x16_F16) //MI300 FP16 16X16 (need tune)
// static constexpr ck_tile::index_t M_Tile = 128;
// static constexpr ck_tile::index_t N_Tile = 128;
// static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 4;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
#elif defined(USING_MFMA_32x32x8_F16) //MI300 FP16 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
// static constexpr ck_tile::index_t M_Warp_Tile = 16;
// static constexpr ck_tile::index_t N_Warp_Tile = 16;
// static constexpr ck_tile::index_t K_Warp_Tile = 32;
// #elif defined(USING_MFMA_32x32x8_F16) //MI300 FP16 32X32 (need tune)
// static constexpr ck_tile::index_t M_Tile = 128;
// static constexpr ck_tile::index_t N_Tile = 128;
// static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 4;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
#else
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
// static constexpr ck_tile::index_t M_Warp_Tile = 32;
// static constexpr ck_tile::index_t N_Warp_Tile = 32;
// static constexpr ck_tile::index_t K_Warp_Tile = 16;
// #else
// static constexpr ck_tile::index_t M_Tile = 128;
// static constexpr ck_tile::index_t N_Tile = 256;
// static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// static constexpr ck_tile::index_t M_Warp = 1;
// static constexpr ck_tile::index_t N_Warp = 4;
// static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
#endif
};
// static constexpr ck_tile::index_t M_Warp_Tile = 16;
// static constexpr ck_tile::index_t N_Warp_Tile = 16;
// static constexpr ck_tile::index_t K_Warp_Tile = 128;
// #endif
// };
auto create_args(int argc, char* argv[])
{