update f16xMXF4

This commit is contained in:
Feng Shijie
2025-08-13 16:16:48 +00:00
parent 732ebdee8b
commit 5de6208952
6 changed files with 113 additions and 48 deletions

View File

@@ -97,6 +97,8 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
using CodegenPipelineProblem = ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
@@ -129,9 +131,10 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
CodegenPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups,
false,
1,
FlatmmConfig::TiledMMAPermuteN>>;
false, // FixedVectorSize
1, // VectorSizeC
FlatmmConfig::TiledMMAPermuteN,
BlockedXDLN_PerWarp>>;
using Kernel =
ck_tile::F16xMXF4FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
@@ -211,10 +214,10 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
}
else
{
// Run(has_hot_loop_,
// tail_number_,
// ck_tile::integral_constant<ck_tile::memory_operation_enum,
// ck_tile::memory_operation_enum::atomic_add>{});
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
@@ -412,17 +415,17 @@ int run_mixed_prec_flatmm_example(int argc, char* argv[])
{
if(persistent_opt == 0)
{
// run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
// ck_tile::pk_fp4_t,
// FlatmmConfig,
// false>(argc, argv, Row{}, Col{}, Row{});
run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
false>(argc, argv, Row{}, Col{}, Row{});
}
else
{
// run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
// ck_tile::pk_fp4_t,
// FlatmmConfig,
// true>(argc, argv, Row{}, Col{}, Row{});
run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
true>(argc, argv, Row{}, Col{}, Row{});
}
}
else if(mixed_prec == "fp16xfp4")
@@ -434,13 +437,13 @@ int run_mixed_prec_flatmm_example(int argc, char* argv[])
FlatmmConfig,
false>(argc, argv, Row{}, Col{}, Row{});
}
// else
// {
// run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
// ck_tile::pk_fp4_t,
// FlatmmConfig,
// true>(argc, argv, Row{}, Col{}, Row{});
// }
else
{
run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
true>(argc, argv, Row{}, Col{}, Row{});
}
}
else
{
@@ -466,10 +469,10 @@ int main(int argc, char* argv[])
{
return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16>(argc, argv);
}
// else if(warp_tile == 1)
// {
// return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
// }
else if(warp_tile == 1)
{
return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
}
else
{
throw std::runtime_error("Unsupported warp_tile!");

View File

@@ -58,8 +58,8 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleType>{-8.f, 8.f}(scale_b);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_b);
}
else if(init_method == 1)
{
@@ -165,8 +165,8 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
const float rtol = 5e-3;
const float atol = 1e-3;
const float rtol = 1e-2;
const float atol = 1e-2;
pass = ck_tile::check_err(
c_rslt_host, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);

View File

@@ -31,10 +31,9 @@ struct e8m0_bexp_t
raw_type data;
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t() : data{type{0b11111111}} {}
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(type init) : data{init} {}
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale)
: e8m0_bexp_t(static_cast<type>(numeric_utils<float>::get_exponent(scale)))
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale) : data(0)
{
data = numeric_utils<float>::get_exponent(scale);
}
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }

View File

@@ -256,7 +256,7 @@ struct buffer_view<address_space_enum::global,
T* p_data_ = nullptr;
BufferSizeType buffer_size_;
int32x4_t cached_buf_res_;
remove_cvref_t<T> invalid_element_value_ = T{0.f};
remove_cvref_t<T> invalid_element_value_ = T{0};
static constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
@@ -269,7 +269,7 @@ struct buffer_view<address_space_enum::global,
: p_data_{p_data},
buffer_size_{buffer_size / PackedSize},
cached_buf_res_{0},
invalid_element_value_{0.f}
invalid_element_value_{0}
{
}

View File

@@ -27,10 +27,11 @@ template <typename ADataType_,
index_t KPerXdl_,
bool isCTransposed_,
memory_operation_enum MemoryOperation_,
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false>
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false,
index_t BlockedXDLN_PerWarp_ = 1> // The number of continuous xdl_output per warp
struct CShuffleEpilogueProblem
{
using ADataType = remove_cvref_t<ADataType_>;
@@ -53,6 +54,7 @@ struct CShuffleEpilogueProblem
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
@@ -89,6 +91,7 @@ struct CShuffleEpilogue
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
@@ -193,7 +196,10 @@ struct CShuffleEpilogue
}
}();
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);
static constexpr index_t NumNXdlPerWavePerShuffle =
max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));
static_assert(NumNXdlPerWavePerShuffle % BlockedXDLN_PerWarp == 0);
static constexpr auto MNPerIterationShuffle = [] {
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
@@ -242,14 +248,31 @@ struct CShuffleEpilogue
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
{
constexpr auto block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto block_outer_dstr_encoding = [] {
if constexpr(BlockedXDLN_PerWarp == 1)
{
return tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
}
else
{
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
// BlockedLayout
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
}
}();
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});

View File

@@ -58,6 +58,46 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
// clang-format on
}
template <class ScaleM, class ScaleN>
CK_TILE_HOST static constexpr auto
// Computes the launch grid for this kernel.
// Persistent mode: sizes the grid to saturate the device (CUs x max resident
// blocks per CU, queried via the occupancy API), capped by the number of work
// tiles so no block is launched without a tile to process.
// Non-persistent mode: one block per work tile, with k_batch in grid.z.
GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
{
if constexpr(UsePersistentKernel)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device
// NOTE(review): queries device 0 unconditionally — assumes the kernel
// will be launched on the current/default device; confirm for
// multi-GPU callers.
constexpr int block_size = F16xMXF4FlatmmKernel::BlockSize().x;
int dync_smem_size = 0;
int maxActiveBlocksPerCU = 0;
// Error codes are intentionally ignored; on failure the occupancy
// result stays 0 and the grid degenerates (presumably acceptable for
// this example path — TODO confirm).
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry2<block_size,
F16xMXF4FlatmmKernel,
FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
block_size,
dync_smem_size);
// Enough blocks to fill every CU at full occupancy...
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
// ...but never more blocks than there are work tiles.
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
// << ", persistent_block_size: " << persistent_block_size
// << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
// Persistent kernels do not support split-K here (k_batch must be 1).
assert(kargs.k_batch == 1);
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
}
else
{
// One block per output tile; split-K batches mapped to grid.z.
return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
}
}
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>