mirror of https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
update scale for mxfp4
@@ -6,40 +6,6 @@
 
 #include "ck_tile/core.hpp"
 
-struct A16W4_FlatmmConfig32
-{
-    static constexpr ck_tile::index_t M_Tile = 128;
-    static constexpr ck_tile::index_t N_Tile = 128;
-    static constexpr ck_tile::index_t K_Tile = 64;
-
-    static constexpr ck_tile::index_t M_Warp = 1;
-    static constexpr ck_tile::index_t N_Warp = 4;
-    static constexpr ck_tile::index_t K_Warp = 1;
-
-    static constexpr ck_tile::index_t M_Warp_Tile = 32;
-    static constexpr ck_tile::index_t N_Warp_Tile = 32;
-    static constexpr ck_tile::index_t K_Warp_Tile = 16;
-
-    static constexpr bool kPadM = false;
-    static constexpr bool kPadN = false;
-    static constexpr bool kPadK = false;
-
-    static constexpr bool TransposeC            = false;
-    static constexpr bool UseStructuredSparsity = false;
-
-    static constexpr int kBlockPerCu            = 1;
-    static constexpr int TileParitionerGroupNum = 8;
-    static constexpr int TileParitionerM01      = 4;
-    static constexpr auto Scheduler             = ck_tile::GemmPipelineScheduler::Default;
-    static constexpr ck_tile::index_t NumWaveGroups = 1;
-    static constexpr bool DoubleSmemBuffer          = false;
-    static constexpr bool TiledMMAPermuteN = false; // disable PermuteN when NWarpTile != 16
-};
-
-struct A16W4_FlatmmConfig32_950 : A16W4_FlatmmConfig32
-{
-};
 
 // GEMM config with 16x16 warp tile
 struct A16W4_FlatmmConfig16
 {
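
Note: the tile parameters in the removed A16W4_FlatmmConfig32 are tied together by the
usual CK-tile decomposition, where each block-tile dimension is covered by
Warp * Warp_Tile * Repeat elements. A minimal sketch of that arithmetic (the derived
names below are illustrative, not part of the commit):

    // Illustrative only: repeat counts implied by A16W4_FlatmmConfig32,
    // assuming Tile = Warp * Warp_Tile * Repeat in each dimension.
    constexpr int M_Repeat = 128 / (1 * 32); // M_Tile / (M_Warp * M_Warp_Tile) = 4
    constexpr int N_Repeat = 128 / (4 * 32); // N_Tile / (N_Warp * N_Warp_Tile) = 1
    constexpr int K_Steps  = 64 / (1 * 16);  // K_Tile / (K_Warp * K_Warp_Tile) = 4
    static_assert(M_Repeat == 4 && N_Repeat == 1 && K_Steps == 4, "tile decomposition");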
@@ -76,7 +42,6 @@ struct A16W4_FlatmmConfig16
 struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
 {
     static constexpr ck_tile::index_t N_Tile = 256;
     static constexpr ck_tile::index_t K_Tile = 128;
-    static constexpr int kBlockPerCu         = 1;
 
     static constexpr int N_Repeat =
@@ -211,10 +211,10 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
        }
        else
        {
-           Run(has_hot_loop_,
-               tail_number_,
-               ck_tile::integral_constant<ck_tile::memory_operation_enum,
-                                          ck_tile::memory_operation_enum::atomic_add>{});
+           // Run(has_hot_loop_,
+           //     tail_number_,
+           //     ck_tile::integral_constant<ck_tile::memory_operation_enum,
+           //                                ck_tile::memory_operation_enum::atomic_add>{});
        }
    };
    BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
@@ -327,11 +327,11 @@ auto create_args(int argc, char* argv[])
    return std::make_tuple(result, arg_parser);
 }
 
-template <class IterSrc, class IterDst>
-void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K, int NXdl)
+template <class FlatmmConfig, class IterSrc, class IterDst>
+void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K)
 {
    int KPack = 16;
-   int NLane = NXdl;
+   int NLane = FlatmmConfig::N_Warp_Tile;
    int KLane = 64 / NLane;
    int K_pk  = K / 2;
    int K0    = K_pk / (KLane * KPack);
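
Note: in preShuffleWeight, K is halved (K_pk = K / 2) because ck_tile::pk_fp4_t packs
two 4-bit weights into one byte, and 64 is the wavefront size, so KLane = 64 / NLane
lanes share each K slice. A sketch of the derived constants for the 16x16 warp tile
(the sample K value is illustrative):

    // Illustrative only: shuffle constants when FlatmmConfig::N_Warp_Tile == 16.
    constexpr int KPack = 16;
    constexpr int NLane = 16;         // FlatmmConfig::N_Warp_Tile
    constexpr int KLane = 64 / NLane; // 4 lanes along K per wavefront
    constexpr int K     = 512;        // sample problem size
    constexpr int K_pk  = K / 2;      // 256 bytes: two fp4 weights per byte
    constexpr int K0    = K_pk / (KLane * KPack); // 256 / 64 = 4 outer K blocks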
@@ -359,12 +359,33 @@ void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K, int NXdl)
        }
    }
 
-template <class IterSrc, class IterDst>
-void preShuffleScale(const IterSrc src, IterDst dst, int N, int K, int NXdl);
+template <class FlatmmConfig, class T>
+auto preShuffleScale(const ck_tile::HostTensor<T>& scale)
+{
+    assert(scale.get_lengths().size() == 2);
+    int n_ = scale.get_lengths()[1];
+    int k_ = scale.get_lengths()[0];
+
+    constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
+    constexpr int K_Pack = FlatmmConfig::K_Tile / FlatmmConfig::K_Warp_Tile / K_Lane;
+
+    static_assert(sizeof(T) * K_Pack * FlatmmConfig::N_Repeat <= 16, "inefficient pack policy");
+
+    ck_tile::HostTensor<T> shfl_scale({
+        n_ / FlatmmConfig::N_Repeat / FlatmmConfig::N_Warp_Tile,
+        FlatmmConfig::N_Repeat,
+        FlatmmConfig::N_Warp_Tile,
+        k_ / K_Pack / K_Lane,
+        K_Pack,
+        K_Lane,
+    });
+    std::copy(scale.begin(), scale.end(), shfl_scale.begin());
+    return ck_tile::reference_permute(shfl_scale, {0, 3, 5, 2, 4, 1});
+}
 
 #include "run_mixed_prec_flatmm.inc"
 
-template <template <typename PrecType> typename FlatmmConfig>
+template <typename FlatmmConfig>
 int run_mixed_prec_flatmm_example(int argc, char* argv[])
 {
    auto [result, arg_parser] = create_args(argc, argv);
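
Note: the static_assert in the new preShuffleScale above bounds the pack size: with
1-byte e8m0 scales, each thread's contiguous chunk of K_Pack * N_Repeat scales should
fit a single 16-byte vector load. A sketch of that check with numbers plugged in (the
K_Warp_Tile and N_Repeat values are assumptions for illustration, not read from the
diff):

    // Illustrative only: the "inefficient pack policy" bound, instantiated.
    constexpr int N_Warp_Tile = 16;
    constexpr int K_Lane      = 64 / N_Warp_Tile;              // 4
    constexpr int K_Tile      = 128;                           // A16W4_FlatmmConfig16_950
    constexpr int K_Warp_Tile = 16;                            // assumed
    constexpr int K_Pack      = K_Tile / K_Warp_Tile / K_Lane; // 2
    constexpr int N_Repeat    = 4;                             // assumed
    static_assert(1 /* sizeof(e8m0_t) */ * K_Pack * N_Repeat <= 16,
                  "fits one 16-byte (dwordx4) load");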
@@ -385,33 +406,33 @@ int run_mixed_prec_flatmm_example(int argc, char* argv[])
    {
        if(persistent_opt == 0)
        {
-           run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
-                                              ck_tile::pk_fp4_t,
-                                              FlatmmConfig<ck_tile::bf16_t>,
-                                              false>(argc, argv, Row{}, Col{}, Row{});
+           // run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
+           //                                    ck_tile::pk_fp4_t,
+           //                                    FlatmmConfig,
+           //                                    false>(argc, argv, Row{}, Col{}, Row{});
        }
        else
        {
            // run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
            //                                    ck_tile::pk_fp4_t,
-           //                                    FlatmmConfig<ck_tile::bf16_t>,
+           //                                    FlatmmConfig,
            //                                    true>(argc, argv, Row{}, Col{}, Row{});
        }
    }
    else if(mixed_prec == "fp16xfp4")
    {
-       // if(persistent_opt == 0)
-       // {
-       //     run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
-       //                                        ck_tile::pk_fp4_t,
-       //                                        FlatmmConfig<ck_tile::fp16_t>,
-       //                                        false>(argc, argv, Row{}, Col{}, Row{});
-       // }
+       if(persistent_opt == 0)
+       {
+           run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
+                                              ck_tile::pk_fp4_t,
+                                              FlatmmConfig,
+                                              false>(argc, argv, Row{}, Col{}, Row{});
+       }
        // else
        // {
        //     run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
        //                                        ck_tile::pk_fp4_t,
-       //                                        FlatmmConfig<ck_tile::fp16_t>,
+       //                                        FlatmmConfig,
        //                                        true>(argc, argv, Row{}, Col{}, Row{});
        // }
    }
@@ -437,11 +458,11 @@ int main(int argc, char* argv[])
    int warp_tile = arg_parser.get_int("warp_tile");
    if(warp_tile == 0)
    {
-       return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
+       return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16>(argc, argv);
    }
    // else if(warp_tile == 1)
    // {
-   //     return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig32_950>(argc, argv);
+   //     return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
    // }
    else
    {
@@ -23,6 +23,8 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
    using CDataType   = PrecActType;
    using AccDataType = float;
 
+   using ScaleType = ck_tile::e8m0_t;
+
    constexpr int DequantGranularityN = 1;
    constexpr int DequantGranularityK = 32;
 
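
Note: ck_tile::e8m0_t is the OCP microscaling (MX) shared-scale format: a bare 8-bit
biased exponent with no sign or mantissa bits. A minimal decode sketch, assuming
bias-127 semantics (the helper is illustrative, not ck_tile API shown in this diff):

    #include <cmath>
    #include <cstdint>

    // Illustrative only: an e8m0 byte encodes 2^(e - 127); the all-ones
    // encoding is NaN in the OCP MX spec and is ignored here.
    inline float e8m0_to_float(std::uint8_t e)
    {
        return std::ldexp(1.0f, static_cast<int>(e) - 127);
    }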
@@ -50,42 +52,42 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
    ck_tile::HostTensor<CDataType> c_rslt_host(
        ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
 
-   ck_tile::HostTensor<AccDataType> weight_dequant_scale(ck_tile::HostTensorDescriptor(
+   ck_tile::HostTensor<ScaleType> scale_b(ck_tile::HostTensorDescriptor(
        {K / DequantGranularityK, N / DequantGranularityN}, {N / DequantGranularityN, 1}));
 
    if(init_method == 0)
    {
        ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
        ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_origin_host);
-       ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(weight_dequant_scale);
+       // ck_tile::FillUniformDistribution<ScaleType>{-8.f, 8.f}(scale_b);
+       ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
    }
    else if(init_method == 1)
    {
        ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
        ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
-       ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(weight_dequant_scale);
+       ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
    }
 
    ck_tile::HostTensor<BDataType> b_shuffle_host(
        ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
-   preShuffleWeight(
-       b_origin_host.begin(), b_shuffle_host.begin(), N, K, FlatmmConfig::N_Warp_Tile);
+   preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffle_host.begin(), N, K);
+
+   ck_tile::HostTensor<ScaleType> scale_b_shuffle = preShuffleScale<FlatmmConfig>(scale_b);
 
    ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
-   ck_tile::DeviceMem weight_dequant_scale_dev_buf(
-       weight_dequant_scale.get_element_space_size_in_bytes());
+   ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffle.get_element_space_size_in_bytes());
 
    a_dev_buf.ToDevice(a_host.data());
    b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
    c_rslt_host.SetZero();
-   weight_dequant_scale_dev_buf.ToDevice(weight_dequant_scale.data());
+   scale_b_dev_buf.ToDevice(scale_b_shuffle.data());
 
-   auto weight_dequant_scale_dev_ptr =
-       ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK>{
-           static_cast<float*>(weight_dequant_scale_dev_buf.GetDeviceBuffer()),
-           N / DequantGranularityN};
+   auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK>{
+       static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN};
 
    invoke_mixed_prec_flatmm<FlatmmConfig,
                             ADataType,
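
Note: DequantGranularityN = 1 and DequantGranularityK = 32 mean one scale per output
column per 32-element K block, matching the mxfp4 block size; the scale tensor above
is shaped {K / 32, N} with row stride N. A sketch of the resulting lookup (the helper
name is hypothetical):

    // Illustrative only: the scale applied to weight element (k, n) under
    // granularity (N = 1, K = 32), given the row-major {K/32, N} layout above.
    inline int scale_index(int k, int n, int N)
    {
        return (k / 32) * N + n;
    }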
@@ -97,7 +99,7 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
                             BLayout,
                             ck_tile::tuple<>,
                             CLayout,
-                            decltype(weight_dequant_scale_dev_ptr),
+                            decltype(scale_b_dev_ptr),
                             UsePersistentKernel>(a_dev_buf,
                                                  b_shuffle_dev_buf,
                                                  c_dev_buf,
@@ -108,7 +110,7 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
                                                  stride_B,
                                                  stride_C,
                                                  kbatch,
-                                                 weight_dequant_scale_dev_ptr,
+                                                 scale_b_dev_ptr,
                                                  n_warmup,
                                                  n_repeat);
 
@@ -126,12 +128,19 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
 
    ck_tile::HostTensor<AccDataType> scale_A(
        ck_tile::HostTensorDescriptor({1, K / DequantGranularityK}, {1, 1}));
 
    // scaleA = 1 has no effect on the result
    ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(scale_A);
 
    ck_tile::DeviceMem scale_A_dev_buf(scale_A.get_element_space_size_in_bytes());
    scale_A_dev_buf.ToDevice(scale_A.data());
 
+   // convert scale_b from e8m0 to float
+   ck_tile::HostTensor<AccDataType> scale_b_float(ck_tile::HostTensorDescriptor(
+       {K / DequantGranularityK, N / DequantGranularityN}, {N / DequantGranularityN, 1}));
+   std::copy(scale_b.begin(), scale_b.end(), scale_b_float.begin());
+   ck_tile::DeviceMem scale_b_float_dev_buf(scale_b_float.get_element_space_size_in_bytes());
+   scale_b_float_dev_buf.ToDevice(scale_b_float.data());
+
    c_gpu_ref_dev_buf.SetZero();
    ck_tile::reference_blockwise_gemm_gpu<ADataType,
                                          BDataType,
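
Note: the std::copy above widens the scales element by element; it compiles only
because e8m0_t converts to float. Written out explicitly, as a sketch (assuming that
conversion; needs <algorithm>):

    // Illustrative only: the e8m0 -> float widening performed by std::copy.
    std::transform(scale_b.begin(), scale_b.end(), scale_b_float.begin(),
                   [](ck_tile::e8m0_t e) { return static_cast<float>(e); });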
@@ -153,7 +162,7 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
                                          DequantGranularityN,
                                          DequantGranularityK,
        static_cast<float*>(scale_A_dev_buf.GetDeviceBuffer()),
-       static_cast<float*>(weight_dequant_scale_dev_buf.GetDeviceBuffer()));
+       static_cast<float*>(scale_b_float_dev_buf.GetDeviceBuffer()));
 
    c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());