Mx fp6 flatmm (#3601)

* add fp6 data-type and support sync/async dwordx3 load/store

* clang-format

* pre-commit

* 1st commit

* default mnk pass ut

* fix a distribution

* fix

* fix bdram distr

* update

* pass ut

* improve perf

* update

* clean code

* resolve copilot comment

* resolve comment

* clang-format

---------

Co-authored-by: ZheWang <zhewan@amd.com>
This commit is contained in:
ZheWang
2026-02-02 16:04:40 +08:00
committed by GitHub
parent 1ae83137eb
commit e6bcd192d4
21 changed files with 761 additions and 136 deletions

View File

@@ -21,7 +21,6 @@ if(has_supported_gpu)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")
add_executable(tile_example_flatmm_basic flatmm_basic.cpp)
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})

View File

@@ -179,10 +179,11 @@ auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
const int K = src_lengths[0];
const int N = src_lengths[1];
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
int KPack = 16 * packed_size; // fp4:32 or fp8:16
int NLane = N_Warp_Tile;
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);
int KPack =
std::is_same_v<dtype, ck_tile::pk_fp6x16_t> ? 32 : 16 * packed_size; // fp4/fp6:32 or fp8:16
int NLane = N_Warp_Tile;
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
@@ -295,7 +296,14 @@ int run_mx_flatmm_example(int argc, char* argv[])
}
else if(mx_prec == "fp6" || mx_prec == "fp6xfp6")
{
throw std::runtime_error("fp6xfp6 is not supported.");
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::pk_fp6x16_t,
ck_tile::pk_fp6x16_t,
ck_tile::fp16_t,
MXfp6_FlatmmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
{

View File

@@ -44,6 +44,38 @@ struct MXfp4_FlatmmConfig16
static constexpr bool TiledMMAPermuteN = false;
};
// Tile/launch configuration for the MX fp6 x fp6 flatmm kernel.
// Field layout intentionally mirrors the sibling MXfp4_FlatmmConfig16 /
// MXfp8_FlatmmConfig16 structs so the instance generator can swap configs freely.
struct MXfp6_FlatmmConfig16
{
// Per work-group tile extents along M / N / K (presumably in elements — TODO confirm units).
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
// Warp grid inside a work-group: 1 x 4 x 1 (M x N x K).
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// Per-warp MMA tile; N_Warp_Tile is also consumed by preShuffleWeight as the N lane count.
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
// No padding on any dimension — callers must supply M/N/K divisible by the tile sizes.
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
// Occupancy hint: one work-group per CU.
static constexpr int kBlockPerCu = 1;
// NOTE(review): "Paritioner" spelling matches the existing project-wide identifier.
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
// Derived: number of N-direction repeats per warp = 256 / 16 / 4 = 4.
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
struct MXfp8_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;

View File

@@ -8,13 +8,14 @@ function(mx_flatmm_instance_generate FILE_LIST)
set(C_LAYOUT ROW)
set(FLATMM_CONFIG_FP4xFP4 "MXfp4_FlatmmConfig16")
set(FLATMM_CONFIG_FP8xFP8 "MXfp8_FlatmmConfig16")
set(FLATMM_CONFIG_FP6xFP6 "MXfp6_FlatmmConfig16")
set(FLATMM_CONFIG_FP8xFP4 "MXf8f4_FlatmmConfig16")
set(FLATMM_CONFIG_FP4xFP8 "MXf4f8_FlatmmConfig16")
# foreach(PERSISTENT false true)
# TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions.
foreach(PERSISTENT false)
foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP8xFP4 FP4xFP8)
foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP6xFP6 FP8xFP4 FP4xFP8)
set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}})
string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE})
list(GET DATA_TYPE_AB 0 A_DATA_TYPE)

View File

@@ -19,6 +19,7 @@
using FP4 = ck_tile::pk_fp4_t;
using FP8 = ck_tile::fp8_t;
using FP6 = ck_tile::pk_fp6x16_t;
using FP16 = ck_tile::fp16_t;
using BF16 = ck_tile::bf16_t;

View File

@@ -68,24 +68,47 @@ int run_mx_flatmm_with_layouts(int argc,
M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::host_tensor_descriptor(
K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout)));
if constexpr(std::is_same_v<ADataType, ck_tile::pk_fp6x16_t>)
{
auto a_buffer_bytes = a_host.get_element_space_size_in_bytes();
auto b_buffer_bytes = b_origin_host.get_element_space_size_in_bytes();
ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_a);
ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_b);
std::vector<int8_t> random_bufA(a_buffer_bytes);
std::vector<int8_t> random_bufB(b_buffer_bytes);
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<int> dis(1, 4);
if(init_method == 0)
{
ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a);
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b);
}
else if(init_method == 1)
{
ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b);
for(size_t i = 0; i < a_buffer_bytes; ++i)
random_bufA[i] = static_cast<int8_t>(dis(gen));
for(size_t i = 0; i < b_buffer_bytes; ++i)
random_bufB[i] = static_cast<int8_t>(dis(gen));
memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes);
memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes);
}
else
{
throw std::runtime_error("wrong! Unexpected init_method");
if(init_method == 0)
{
ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a);
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b);
}
else if(init_method == 1)
{
ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b);
}
else
{
throw std::runtime_error("wrong! Unexpected init_method");
}
}
const auto b_shuffled_host = preShuffleWeight<FlatmmConfig::N_Warp_Tile>(b_origin_host);