mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] Add Flatmm MX FP8 (#3208)
* Use async for flatmm mxfp4 * Fix preshuffle * Add flatmm mxfp8 * Thanks, Copilot * Thanks Copilot again~
This commit is contained in:
@@ -136,7 +136,7 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run MXFP4_Flatmm kernel " //
|
||||
std::cout << "Run " << ck_tile::gemm_prec_str<ADataType, BDataType>() << " Flatmm kernel " //
|
||||
<< " M = " << M << " N = " << N << " K = " << K << " StrideA = " << stride_A
|
||||
<< " StrideB = " << stride_B << " StrideC = " << stride_C << " : " << ave_time
|
||||
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
@@ -172,42 +172,47 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <class FlatmmConfig, class IterSrc, class IterDst>
|
||||
void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K)
|
||||
template <ck_tile::index_t N_Warp_Tile, typename dtype>
|
||||
auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = FlatmmConfig::N_Warp_Tile;
|
||||
int KLane = 64 / NLane;
|
||||
int K_pk = K / 2;
|
||||
int K0 = K_pk / (KLane * KPack);
|
||||
auto src_lengths = src.get_lengths();
|
||||
const int K = src_lengths[0];
|
||||
const int N = src_lengths[1];
|
||||
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
|
||||
int KPack = 16 * packed_size; // fp4:32 or fp8:16
|
||||
int NLane = N_Warp_Tile;
|
||||
int KLane = 64 / NLane;
|
||||
int K0 = K / (KLane * KPack);
|
||||
|
||||
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
|
||||
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
int tempk;
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K_pk; ++k)
|
||||
for(int k = 0; k < K; k += packed_size)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
int k0 = k / (KLane * KPack);
|
||||
int tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[n * K_pk + k];
|
||||
shuffled(outputIndex) = src(k, n);
|
||||
}
|
||||
}
|
||||
return shuffled;
|
||||
}
|
||||
|
||||
template <class FlatmmConfig, bool KLast, typename Src>
|
||||
auto preShuffleScale(Src& src)
|
||||
template <class FlatmmConfig, bool KLast, typename dtype>
|
||||
auto preShuffleScale(ck_tile::HostTensor<dtype>& src)
|
||||
{
|
||||
using dtype = typename Src::Data::value_type;
|
||||
auto src_lengths = src.get_lengths();
|
||||
const auto MN = KLast ? src_lengths[0] : src_lengths[1];
|
||||
const auto K = KLast ? src_lengths[1] : src_lengths[0];
|
||||
@@ -261,7 +266,6 @@ auto preShuffleScale(Src& src)
|
||||
|
||||
#include "run_mx_flatmm.inc"
|
||||
|
||||
template <typename FlatmmConfig>
|
||||
int run_mx_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -278,24 +282,31 @@ int run_mx_flatmm_example(int argc, char* argv[])
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(mx_prec == "fp4xfp4")
|
||||
if(mx_prec == "fp4" || mx_prec == "fp4xfp4")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::fp16_t,
|
||||
FlatmmConfig,
|
||||
MXfp4_FlatmmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only non-persistent kernels are supported currently!");
|
||||
}
|
||||
else if(mx_prec == "fp6xfp6")
|
||||
else if(mx_prec == "fp6" || mx_prec == "fp6xfp6")
|
||||
{
|
||||
throw std::runtime_error("Only support fp4xfp4 now!");
|
||||
throw std::runtime_error("fp6xfp6 is not supported.");
|
||||
}
|
||||
else if(mx_prec == "fp8xfp8")
|
||||
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
|
||||
{
|
||||
throw std::runtime_error("Only support fp4xfp4 now!");
|
||||
if(persistent_opt == 0)
|
||||
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp16_t,
|
||||
MXfp8_FlatmmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only support non-persistent kernel now!");
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -306,7 +317,6 @@ int run_mx_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -319,7 +329,7 @@ int main(int argc, char* argv[])
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return run_mx_flatmm_example<MXfp4_FlatmmConfig16>(argc, argv);
|
||||
return run_mx_flatmm_example(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
|
||||
@@ -12,4 +12,87 @@
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include "mxfp4_flatmm.hpp"
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct MXfp4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 512;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
struct MXfp8_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
typename CDEElementWise,
|
||||
bool Splitk,
|
||||
bool HasHotLoop,
|
||||
ck_tile::TailNumber TailNum>
|
||||
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s);
|
||||
|
||||
@@ -1,24 +1,29 @@
|
||||
function(mx_flatmm_instance_generate FILE_LIST)
|
||||
set(FLATMM_CONFIG MXfp4_FlatmmConfig16)
|
||||
set(A_DATA_TYPE FP4)
|
||||
set(B_DATA_TYPE FP4)
|
||||
set(C_DATA_TYPE FP16)
|
||||
set(A_LAYOUT ROW)
|
||||
set(B_LAYOUT COL)
|
||||
set(C_LAYOUT ROW)
|
||||
set(FLATMM_CONFIG_FP4 "MXfp4_FlatmmConfig16")
|
||||
set(FLATMM_CONFIG_FP8 "MXfp8_FlatmmConfig16")
|
||||
|
||||
# foreach(PERSISTENT false true)
|
||||
# TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions.
|
||||
foreach(PERSISTENT false)
|
||||
foreach(SPLIT_K false true)
|
||||
foreach(HAS_HOT_LOOP false true)
|
||||
foreach(TAIL_NUMBER ODD EVEN)
|
||||
set(KERNEL_FILE mxgemm/mx_flatmm_instance_${PERSISTENT}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
|
||||
configure_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
|
||||
@ONLY)
|
||||
list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
|
||||
foreach(DATA_TYPE FP4 FP8)
|
||||
set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}})
|
||||
set(A_DATA_TYPE ${DATA_TYPE})
|
||||
set(B_DATA_TYPE ${DATA_TYPE})
|
||||
foreach(SPLIT_K false true)
|
||||
foreach(HAS_HOT_LOOP false true)
|
||||
foreach(TAIL_NUMBER ODD EVEN)
|
||||
set(KERNEL_FILE mxgemm/mx_flatmm_instance_${PERSISTENT}_${DATA_TYPE}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
|
||||
string(TOLOWER ${KERNEL_FILE} KERNEL_FILE)
|
||||
configure_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
|
||||
@ONLY)
|
||||
list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
// clang-format on
|
||||
|
||||
using FP4 = ck_tile::pk_fp4_t;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using FP16 = ck_tile::fp16_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct MXfp4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 512;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
typename CDEElementWise,
|
||||
bool Splitk,
|
||||
bool HasHotLoop,
|
||||
ck_tile::TailNumber TailNum>
|
||||
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s);
|
||||
@@ -88,10 +88,7 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
throw std::runtime_error("wrong! Unexpected init_method");
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffled_host(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffled_host.begin(), N, K);
|
||||
|
||||
const auto b_shuffled_host = preShuffleWeight<FlatmmConfig::N_Warp_Tile>(b_origin_host);
|
||||
const auto scale_a_shuffled = preShuffleScale<FlatmmConfig, true>(scale_a);
|
||||
const auto scale_b_shuffled = preShuffleScale<FlatmmConfig, false>(scale_b);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user