[CK_TILE] Add FP8xFP4 Flatmm (#3401)

* Refactor policy

* Fix a bank conflict

* Enable mixed mx flatmm

* Update
This commit is contained in:
Yi DING
2025-12-17 10:01:48 +08:00
committed by GitHub
parent 3dfa794fab
commit 57e1e4a848
9 changed files with 231 additions and 223 deletions

View File

@@ -148,7 +148,7 @@ auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "32", "m dimension")
.insert("n", "128", "n dimension")
.insert("n", "512", "n dimension")
.insert("k", "256", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
@@ -308,6 +308,28 @@ int run_mx_flatmm_example(int argc, char* argv[])
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else if(mx_prec == "fp8xfp4")
{
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
ck_tile::pk_fp4_t,
ck_tile::fp16_t,
MXf8f4_FlatmmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else if(mx_prec == "fp4xfp8")
{
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
ck_tile::fp8_t,
ck_tile::fp16_t,
MXf4f8_FlatmmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else
{
throw std::runtime_error("Unsupported data_type!");

View File

@@ -76,6 +76,69 @@ struct MXfp8_FlatmmConfig16
static constexpr bool TiledMMAPermuteN = false;
};
// Compile-time configuration bundle for the mixed-precision MX flatmm variant
// named "f8f4". The example driver selects it for mx_prec == "fp8xfp4", where
// it is paired with ck_tile::fp8_t and ck_tile::pk_fp4_t as the first two type
// arguments. Every member is a static constant; the struct has no instance state.
struct MXf8f4_FlatmmConfig16
{
    // Tile extents along M / N / K.
    // NOTE(review): presumably the per-workgroup tile — confirm against the pipeline.
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 256;

    // Warp counts along M / N / K.
    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    // Per-warp tile extents along M / N / K.
    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 128;

    // Derived: N_Tile divided by the N extent covered by all warps together.
    // Equivalent to N_Tile / N_Warp_Tile / N_Warp for these positive constants.
    static constexpr int N_Repeat = N_Tile / (N_Warp_Tile * N_Warp);

    // Padding switches for each problem dimension (all disabled here).
    static constexpr bool kPadM = false;
    static constexpr bool kPadN = false;
    static constexpr bool kPadK = false;

    // Feature toggles (all disabled here).
    static constexpr bool TransposeC            = false;
    static constexpr bool UseStructuredSparsity = false;
    static constexpr bool DoubleSmemBuffer      = false;
    static constexpr bool TiledMMAPermuteN      = false;

    // Scheduling / partitioning knobs. The "Paritioner" spelling is kept as-is
    // because the member names are part of this config's interface.
    static constexpr int kBlockPerCu                = 1;
    static constexpr int TileParitionerGroupNum     = 8;
    static constexpr int TileParitionerM01          = 4;
    static constexpr ck_tile::index_t NumWaveGroups = 1;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
};
// Compile-time configuration bundle for the mixed-precision MX flatmm variant
// named "f4f8". The example driver selects it for mx_prec == "fp4xfp8", where
// it is paired with ck_tile::pk_fp4_t and ck_tile::fp8_t as the first two type
// arguments. Its values currently match MXf8f4_FlatmmConfig16, but it remains a
// distinct type. Every member is a static constant.
struct MXf4f8_FlatmmConfig16
{
    // Tile extents along M / N / K.
    // NOTE(review): presumably the per-workgroup tile — confirm against the pipeline.
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 256;

    // Warp counts along M / N / K.
    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    // Per-warp tile extents along M / N / K.
    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 128;

    // Derived: N_Tile divided by the N extent covered by all warps together.
    // Equivalent to N_Tile / N_Warp_Tile / N_Warp for these positive constants.
    static constexpr int N_Repeat = N_Tile / (N_Warp_Tile * N_Warp);

    // Padding switches for each problem dimension (all disabled here).
    static constexpr bool kPadM = false;
    static constexpr bool kPadN = false;
    static constexpr bool kPadK = false;

    // Feature toggles (all disabled here).
    static constexpr bool TransposeC            = false;
    static constexpr bool UseStructuredSparsity = false;
    static constexpr bool DoubleSmemBuffer      = false;
    static constexpr bool TiledMMAPermuteN      = false;

    // Scheduling / partitioning knobs. The "Paritioner" spelling is kept as-is
    // because the member names are part of this config's interface.
    static constexpr int kBlockPerCu                = 1;
    static constexpr int TileParitionerGroupNum     = 8;
    static constexpr int TileParitionerM01          = 4;
    static constexpr ck_tile::index_t NumWaveGroups = 1;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
};
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,

View File

@@ -6,16 +6,20 @@ function(mx_flatmm_instance_generate FILE_LIST)
set(A_LAYOUT ROW)
set(B_LAYOUT COL)
set(C_LAYOUT ROW)
set(FLATMM_CONFIG_FP4 "MXfp4_FlatmmConfig16")
set(FLATMM_CONFIG_FP8 "MXfp8_FlatmmConfig16")
set(FLATMM_CONFIG_FP4xFP4 "MXfp4_FlatmmConfig16")
set(FLATMM_CONFIG_FP8xFP8 "MXfp8_FlatmmConfig16")
set(FLATMM_CONFIG_FP8xFP4 "MXf8f4_FlatmmConfig16")
set(FLATMM_CONFIG_FP4xFP8 "MXf4f8_FlatmmConfig16")
# foreach(PERSISTENT false true)
# TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions.
foreach(PERSISTENT false)
foreach(DATA_TYPE FP4 FP8)
foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP8xFP4 FP4xFP8)
set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}})
set(A_DATA_TYPE ${DATA_TYPE})
set(B_DATA_TYPE ${DATA_TYPE})
string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE})
list(GET DATA_TYPE_AB 0 A_DATA_TYPE)
list(GET DATA_TYPE_AB 1 B_DATA_TYPE)
foreach(SPLIT_K false true)
foreach(HAS_HOT_LOOP false true)
foreach(TAIL_NUMBER ODD EVEN)