mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
update scale-preshuffle for MXF4
This commit is contained in:
@@ -11,7 +11,7 @@ struct A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
|
||||
@@ -97,17 +97,17 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::MixedPrecFlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using CodegenPipelineProblem = ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::MixedPrecFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
@@ -134,7 +134,7 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
|
||||
FlatmmConfig::TiledMMAPermuteN>>;
|
||||
|
||||
using Kernel =
|
||||
ck_tile::MixedPrecFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
|
||||
ck_tile::F16xMXF4FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -282,10 +282,11 @@ float invoke_mixed_prec_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
// TODO (sizeof(BDataType) / 2)
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
constexpr int PackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * N * K / PackedSize +
|
||||
sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
@@ -366,23 +367,26 @@ auto preShuffleScale(const ck_tile::HostTensor<T>& scale)
|
||||
int n_ = scale.get_lengths()[1];
|
||||
int k_ = scale.get_lengths()[0];
|
||||
|
||||
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
|
||||
constexpr int K_Pack = FlatmmConfig::K_Tile / FlatmmConfig::K_Warp_Tile / K_Lane;
|
||||
constexpr int K_Pack = 2; // fixed for mxfp4
|
||||
constexpr int N_Pack = 2; // fixed for mxfp4
|
||||
constexpr int GranularityK = 32; // fixed for mxfp4
|
||||
|
||||
static_assert(sizeof(T) * K_Pack * FlatmmConfig::N_Repeat <= 16, "inefficient pack policy");
|
||||
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
|
||||
|
||||
static_assert(FlatmmConfig::N_Warp_Tile == 16, "only support XDL_N == 16");
|
||||
static_assert(FlatmmConfig::N_Repeat % N_Pack == 0);
|
||||
static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0);
|
||||
|
||||
ck_tile::HostTensor<T> shfl_scale({
|
||||
k_ / K_Pack / K_Lane,
|
||||
K_Pack,
|
||||
K_Lane,
|
||||
n_ / FlatmmConfig::N_Tile,
|
||||
FlatmmConfig::N_Repeat,
|
||||
FlatmmConfig::N_Warp,
|
||||
n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
|
||||
N_Pack,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
});
|
||||
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
|
||||
// return ck_tile::reference_permute(shfl_scale, {0, 3, 5, 2, 4, 1});
|
||||
return ck_tile::reference_permute(shfl_scale, {3, 5, 0, 2, 6, 1, 4});
|
||||
return ck_tile::reference_permute(shfl_scale, {3, 0, 2, 5, 1, 4});
|
||||
}
|
||||
|
||||
#include "run_mixed_prec_flatmm.inc"
|
||||
|
||||
@@ -59,8 +59,7 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_origin_host);
|
||||
// ck_tile::FillUniformDistribution<ScaleType>{-8.f, 8.f}(scale_b);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{-8.f, 8.f}(scale_b);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
@@ -166,7 +165,7 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
|
||||
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
|
||||
|
||||
const float rtol = 1e-3;
|
||||
const float rtol = 5e-3;
|
||||
const float atol = 1e-3;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
|
||||
Reference in New Issue
Block a user