[CK_TILE] Add Flatmm MX FP8 (#3208)

* Use async for flatmm mxfp4

* Fix preshuffle

* Add flatmm mxfp8

* Thanks, Copilot

* Thanks Copilot again~
This commit is contained in:
Yi DING
2025-11-20 10:35:15 +08:00
committed by GitHub
parent 4e49e0228b
commit 47e2ed838e
17 changed files with 698 additions and 595 deletions

View File

@@ -136,7 +136,7 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run MXFP4_Flatmm kernel " //
std::cout << "Run " << ck_tile::gemm_prec_str<ADataType, BDataType>() << " Flatmm kernel " //
<< " M = " << M << " N = " << N << " K = " << K << " StrideA = " << stride_A
<< " StrideB = " << stride_B << " StrideC = " << stride_C << " : " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
@@ -172,42 +172,47 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
template <class FlatmmConfig, class IterSrc, class IterDst>
void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K)
template <ck_tile::index_t N_Warp_Tile, typename dtype>
auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
{
int KPack = 16;
int NLane = FlatmmConfig::N_Warp_Tile;
int KLane = 64 / NLane;
int K_pk = K / 2;
int K0 = K_pk / (KLane * KPack);
auto src_lengths = src.get_lengths();
const int K = src_lengths[0];
const int N = src_lengths[1];
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
int KPack = 16 * packed_size; // fp4:32 or fp8:16
int NLane = N_Warp_Tile;
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K_pk; ++k)
for(int k = 0; k < K; k += packed_size)
{
int n0 = n / NLane;
int n1 = n % NLane;
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int k0 = k / (KLane * KPack);
int tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K_pk + k];
shuffled(outputIndex) = src(k, n);
}
}
return shuffled;
}
template <class FlatmmConfig, bool KLast, typename Src>
auto preShuffleScale(Src& src)
template <class FlatmmConfig, bool KLast, typename dtype>
auto preShuffleScale(ck_tile::HostTensor<dtype>& src)
{
using dtype = typename Src::Data::value_type;
auto src_lengths = src.get_lengths();
const auto MN = KLast ? src_lengths[0] : src_lengths[1];
const auto K = KLast ? src_lengths[1] : src_lengths[0];
@@ -261,7 +266,6 @@ auto preShuffleScale(Src& src)
#include "run_mx_flatmm.inc"
template <typename FlatmmConfig>
int run_mx_flatmm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -278,24 +282,31 @@ int run_mx_flatmm_example(int argc, char* argv[])
if(a_layout == "R" && b_layout == "C")
{
if(mx_prec == "fp4xfp4")
if(mx_prec == "fp4" || mx_prec == "fp4xfp4")
{
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
ck_tile::fp16_t,
FlatmmConfig,
MXfp4_FlatmmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only non-persistent kernels are supported currently!");
}
else if(mx_prec == "fp6xfp6")
else if(mx_prec == "fp6" || mx_prec == "fp6xfp6")
{
throw std::runtime_error("Only support fp4xfp4 now!");
throw std::runtime_error("fp6xfp6 is not supported.");
}
else if(mx_prec == "fp8xfp8")
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
{
throw std::runtime_error("Only support fp4xfp4 now!");
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::fp16_t,
MXfp8_FlatmmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else
{
@@ -306,7 +317,6 @@ int run_mx_flatmm_example(int argc, char* argv[])
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
return -1;
}
int main(int argc, char* argv[])
@@ -319,7 +329,7 @@ int main(int argc, char* argv[])
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
return run_mx_flatmm_example<MXfp4_FlatmmConfig16>(argc, argv);
return run_mx_flatmm_example(argc, argv);
}
else if(warp_tile == 1)
{