update scale-preshuffle for MXF4

This commit is contained in:
Feng Shijie
2025-08-13 10:48:53 +00:00
parent edb58d0680
commit 732ebdee8b
6 changed files with 376 additions and 401 deletions

View File

@@ -11,7 +11,7 @@ struct A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;

View File

@@ -97,17 +97,17 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using CodegenPipelineProblem = ck_tile::MixedPrecFlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenPipelineProblem = ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenFlatmmPipeline =
ck_tile::MixedPrecFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
@@ -134,7 +134,7 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
FlatmmConfig::TiledMMAPermuteN>>;
using Kernel =
ck_tile::MixedPrecFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
ck_tile::F16xMXF4FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
@@ -282,10 +282,11 @@ float invoke_mixed_prec_flatmm(ck_tile::DeviceMem& a_dev_buf,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
// TODO (sizeof(BDataType) / 2)
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
constexpr int PackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * N * K / PackedSize +
sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
@@ -366,23 +367,26 @@ auto preShuffleScale(const ck_tile::HostTensor<T>& scale)
int n_ = scale.get_lengths()[1];
int k_ = scale.get_lengths()[0];
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
constexpr int K_Pack = FlatmmConfig::K_Tile / FlatmmConfig::K_Warp_Tile / K_Lane;
constexpr int K_Pack = 2; // fixed for mxfp4
constexpr int N_Pack = 2; // fixed for mxfp4
constexpr int GranularityK = 32; // fixed for mxfp4
static_assert(sizeof(T) * K_Pack * FlatmmConfig::N_Repeat <= 16, "inefficient pack policy");
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
static_assert(FlatmmConfig::N_Warp_Tile == 16, "only support XDL_N == 16");
static_assert(FlatmmConfig::N_Repeat % N_Pack == 0);
static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0);
ck_tile::HostTensor<T> shfl_scale({
k_ / K_Pack / K_Lane,
K_Pack,
K_Lane,
n_ / FlatmmConfig::N_Tile,
FlatmmConfig::N_Repeat,
FlatmmConfig::N_Warp,
n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
N_Pack,
FlatmmConfig::N_Warp_Tile,
});
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
// return ck_tile::reference_permute(shfl_scale, {0, 3, 5, 2, 4, 1});
return ck_tile::reference_permute(shfl_scale, {3, 5, 0, 2, 6, 1, 4});
return ck_tile::reference_permute(shfl_scale, {3, 0, 2, 5, 1, 4});
}
#include "run_mixed_prec_flatmm.inc"

View File

@@ -59,8 +59,7 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
{
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_origin_host);
// ck_tile::FillUniformDistribution<ScaleType>{-8.f, 8.f}(scale_b);
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
ck_tile::FillUniformDistribution<ScaleType>{-8.f, 8.f}(scale_b);
}
else if(init_method == 1)
{
@@ -166,7 +165,7 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
const float rtol = 1e-3;
const float rtol = 5e-3;
const float atol = 1e-3;
pass = ck_tile::check_err(