update f16xMXF4

This commit is contained in:
Feng Shijie
2025-08-13 16:16:48 +00:00
parent 732ebdee8b
commit 5de6208952
6 changed files with 113 additions and 48 deletions

View File

@@ -27,10 +27,11 @@ template <typename ADataType_,
index_t KPerXdl_,
bool isCTransposed_,
memory_operation_enum MemoryOperation_,
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false>
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false,
index_t BlockedXDLN_PerWarp_ = 1> // The number of continuous xdl_output per warp
struct CShuffleEpilogueProblem
{
using ADataType = remove_cvref_t<ADataType_>;
@@ -53,6 +54,7 @@ struct CShuffleEpilogueProblem
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
@@ -89,6 +91,7 @@ struct CShuffleEpilogue
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
@@ -193,7 +196,10 @@ struct CShuffleEpilogue
}
}();
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);
static constexpr index_t NumNXdlPerWavePerShuffle =
max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));
static_assert(NumNXdlPerWavePerShuffle % BlockedXDLN_PerWarp == 0);
static constexpr auto MNPerIterationShuffle = [] {
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
@@ -242,14 +248,31 @@ struct CShuffleEpilogue
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
{
constexpr auto block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto block_outer_dstr_encoding = [] {
if constexpr(BlockedXDLN_PerWarp == 1)
{
return tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
}
else
{
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
// BlockedLayout
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
}
}();
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});

View File

@@ -58,6 +58,46 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
// clang-format on
}
template <class ScaleM, class ScaleN>
CK_TILE_HOST static constexpr auto
GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
{
if constexpr(UsePersistentKernel)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device
constexpr int block_size = F16xMXF4FlatmmKernel::BlockSize().x;
int dync_smem_size = 0;
int maxActiveBlocksPerCU = 0;
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry2<block_size,
F16xMXF4FlatmmKernel,
FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
block_size,
dync_smem_size);
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
// << ", persistent_block_size: " << persistent_block_size
// << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
assert(kargs.k_batch == 1);
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
}
else
{
return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
}
}
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>