[CK_TILE] Add Flatmm MX FP8 (#3208)

* Use async for flatmm mxfp4

* Fix preshuffle

* Add flatmm mxfp8

* Thanks, Copilot

* Thanks Copilot again~
This commit is contained in:
Yi DING
2025-11-20 10:35:15 +08:00
committed by GitHub
parent 4e49e0228b
commit 47e2ed838e
17 changed files with 698 additions and 595 deletions

View File

@@ -136,7 +136,7 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run MXFP4_Flatmm kernel " //
std::cout << "Run " << ck_tile::gemm_prec_str<ADataType, BDataType>() << " Flatmm kernel " //
<< " M = " << M << " N = " << N << " K = " << K << " StrideA = " << stride_A
<< " StrideB = " << stride_B << " StrideC = " << stride_C << " : " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
@@ -172,42 +172,47 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
template <class FlatmmConfig, class IterSrc, class IterDst>
void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K)
template <ck_tile::index_t N_Warp_Tile, typename dtype>
auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
{
int KPack = 16;
int NLane = FlatmmConfig::N_Warp_Tile;
int KLane = 64 / NLane;
int K_pk = K / 2;
int K0 = K_pk / (KLane * KPack);
auto src_lengths = src.get_lengths();
const int K = src_lengths[0];
const int N = src_lengths[1];
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
int KPack = 16 * packed_size; // fp4:32 or fp8:16
int NLane = N_Warp_Tile;
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K_pk; ++k)
for(int k = 0; k < K; k += packed_size)
{
int n0 = n / NLane;
int n1 = n % NLane;
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int k0 = k / (KLane * KPack);
int tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K_pk + k];
shuffled(outputIndex) = src(k, n);
}
}
return shuffled;
}
template <class FlatmmConfig, bool KLast, typename Src>
auto preShuffleScale(Src& src)
template <class FlatmmConfig, bool KLast, typename dtype>
auto preShuffleScale(ck_tile::HostTensor<dtype>& src)
{
using dtype = typename Src::Data::value_type;
auto src_lengths = src.get_lengths();
const auto MN = KLast ? src_lengths[0] : src_lengths[1];
const auto K = KLast ? src_lengths[1] : src_lengths[0];
@@ -261,7 +266,6 @@ auto preShuffleScale(Src& src)
#include "run_mx_flatmm.inc"
template <typename FlatmmConfig>
int run_mx_flatmm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -278,24 +282,31 @@ int run_mx_flatmm_example(int argc, char* argv[])
if(a_layout == "R" && b_layout == "C")
{
if(mx_prec == "fp4xfp4")
if(mx_prec == "fp4" || mx_prec == "fp4xfp4")
{
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
ck_tile::fp16_t,
FlatmmConfig,
MXfp4_FlatmmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only non-persistent kernels are supported currently!");
}
else if(mx_prec == "fp6xfp6")
else if(mx_prec == "fp6" || mx_prec == "fp6xfp6")
{
throw std::runtime_error("Only support fp4xfp4 now!");
throw std::runtime_error("fp6xfp6 is not supported.");
}
else if(mx_prec == "fp8xfp8")
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
{
throw std::runtime_error("Only support fp4xfp4 now!");
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::fp16_t,
MXfp8_FlatmmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else
{
@@ -306,7 +317,6 @@ int run_mx_flatmm_example(int argc, char* argv[])
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
return -1;
}
int main(int argc, char* argv[])
@@ -319,7 +329,7 @@ int main(int argc, char* argv[])
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
return run_mx_flatmm_example<MXfp4_FlatmmConfig16>(argc, argv);
return run_mx_flatmm_example(argc, argv);
}
else if(warp_tile == 1)
{

View File

@@ -12,4 +12,87 @@
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "mxfp4_flatmm.hpp"
// GEMM config with 16x16 warp tile
struct MXfp4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 512;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
struct MXfp8_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise,
bool Splitk,
bool HasHotLoop,
ck_tile::TailNumber TailNum>
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s);

View File

@@ -1,24 +1,29 @@
function(mx_flatmm_instance_generate FILE_LIST)
set(FLATMM_CONFIG MXfp4_FlatmmConfig16)
set(A_DATA_TYPE FP4)
set(B_DATA_TYPE FP4)
set(C_DATA_TYPE FP16)
set(A_LAYOUT ROW)
set(B_LAYOUT COL)
set(C_LAYOUT ROW)
set(FLATMM_CONFIG_FP4 "MXfp4_FlatmmConfig16")
set(FLATMM_CONFIG_FP8 "MXfp8_FlatmmConfig16")
# foreach(PERSISTENT false true)
# TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions.
foreach(PERSISTENT false)
foreach(SPLIT_K false true)
foreach(HAS_HOT_LOOP false true)
foreach(TAIL_NUMBER ODD EVEN)
set(KERNEL_FILE mxgemm/mx_flatmm_instance_${PERSISTENT}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in
${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
@ONLY)
list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
foreach(DATA_TYPE FP4 FP8)
set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}})
set(A_DATA_TYPE ${DATA_TYPE})
set(B_DATA_TYPE ${DATA_TYPE})
foreach(SPLIT_K false true)
foreach(HAS_HOT_LOOP false true)
foreach(TAIL_NUMBER ODD EVEN)
set(KERNEL_FILE mxgemm/mx_flatmm_instance_${PERSISTENT}_${DATA_TYPE}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
string(TOLOWER ${KERNEL_FILE} KERNEL_FILE)
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in
${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
@ONLY)
list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
endforeach()
endforeach()
endforeach()
endforeach()

View File

@@ -18,6 +18,7 @@
// clang-format on
using FP4 = ck_tile::pk_fp4_t;
using FP8 = ck_tile::fp8_t;
using FP16 = ck_tile::fp16_t;
using BF16 = ck_tile::bf16_t;

View File

@@ -1,60 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
// GEMM config with 16x16 warp tile
struct MXfp4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 512;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise,
bool Splitk,
bool HasHotLoop,
ck_tile::TailNumber TailNum>
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s);

View File

@@ -88,10 +88,7 @@ int run_mx_flatmm_with_layouts(int argc,
throw std::runtime_error("wrong! Unexpected init_method");
}
ck_tile::HostTensor<BDataType> b_shuffled_host(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffled_host.begin(), N, K);
const auto b_shuffled_host = preShuffleWeight<FlatmmConfig::N_Warp_Tile>(b_origin_host);
const auto scale_a_shuffled = preShuffleScale<FlatmmConfig, true>(scale_a);
const auto scale_b_shuffled = preShuffleScale<FlatmmConfig, false>(scale_b);