[CK_TILE] Optimize Flatmm MXFP4 by Eliminating Runtime Division by 2 (#3287)

* [CK_TILE] Optimize Flatmm MXFP4 by Eliminating Runtime Division by 2

* typo
This commit is contained in:
Yi DING
2025-12-08 19:20:44 +08:00
committed by GitHub
parent 04612c30ce
commit 878b4e7f46
3 changed files with 141 additions and 121 deletions

View File

@@ -18,21 +18,21 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
{
using Underlying = FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using FlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using MXFlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
using BlockGemmShape =
remove_cvref_t<typename MXFlatmmPipeline_::BlockGemmShape>; // TileFlatmmShape
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
using ALayout = remove_cvref_t<typename MXFlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename MXFlatmmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename MXFlatmmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
static constexpr index_t KernelBlockSize = MXFlatmmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = MXFlatmmPipeline::UsePersistentKernel;
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
using ADataType = remove_cvref_t<typename MXFlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename MXFlatmmPipeline::BDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
@@ -43,9 +43,9 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
static constexpr int APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr int BPackedSize = numeric_traits<BDataType>::PackedSize;
static constexpr int MXdlPack = FlatmmPipeline::MXdlPack;
static constexpr int NXdlPack = FlatmmPipeline::NXdlPack;
static constexpr int KXdlPack = FlatmmPipeline::KXdlPack;
static constexpr int MXdlPack = MXFlatmmPipeline::MXdlPack;
static constexpr int NXdlPack = MXFlatmmPipeline::NXdlPack;
static constexpr int KXdlPack = MXFlatmmPipeline::KXdlPack;
static constexpr index_t NumDTensor = DsDataType::size();
@@ -63,7 +63,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, MXFlatmmPipeline::GetName());
// clang-format on
}
@@ -123,33 +123,23 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
const SplitKBatchOffset& splitk_batch_offset)
{
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
"A tensor for mx must be RowMajor");
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
number<MXFlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}();
constexpr index_t kKPerBlock = FlatmmPipeline::kKPerBlock;
constexpr index_t kKPerBlock = MXFlatmmPipeline::kKPerBlock;
constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
const index_t kFlatKBlocks = kargs.K / kKPerBlock;
const index_t kFlatN = kargs.N / kNWarpTile;
const auto& b_flat_tensor_view = [&]() {
static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0,
static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0,
"wrong! vector size for B tensor");
auto&& naive_desc = make_naive_tensor_descriptor_packed(
make_tuple(kFlatN, kFlatKBlocks, number<flatKPerBlock>{}));
@@ -262,20 +252,12 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
"A tensor for mx must be RowMajor");
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, MXFlatmmPipeline::kPadK>{});
}();
const auto& b_flat_tensor_view = views.at(I1);
@@ -289,14 +271,14 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
sequence<false, MXFlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
sequence<false, MXFlatmmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
@@ -309,14 +291,14 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
sequence<false, MXFlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<FlatmmPipeline::kPadM, false>{});
sequence<MXFlatmmPipeline::kPadM, false>{});
}
}();
@@ -334,26 +316,18 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
const auto& e_pad_view = views.at(I3);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
"A tensor for mx must be RowMajor");
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}();
const auto& b_flat_block_window =
make_tile_window(b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
make_tuple(number<MXFlatmmPipeline::flatNPerWarp>{},
number<MXFlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
const auto ds_block_window = generate_tuple(
@@ -444,14 +418,14 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
a_block_window.get_window_lengths(),
a_block_window.get_window_origin(),
FlatmmPipeline::GetADramTileDistribution());
const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
b_flat_block_window,
scale_a_block_window,
scale_b_block_window,
num_loop,
smem_ptr_ping,
smem_ptr_pong);
MXFlatmmPipeline::GetADramTileDistribution());
const auto& c_block_tile = MXFlatmmPipeline{}(a_block_window_with_distr,
b_flat_block_window,
scale_a_block_window,
scale_b_block_window,
num_loop,
smem_ptr_ping,
smem_ptr_pong);
// Run Epilogue Pipeline
if constexpr(DoEpiScale)
@@ -487,10 +461,10 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
const SplitKBatchOffset splitk_batch_offset(kargs);
// options
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
splitk_batch_offset.a_k_split_offset / APackedSize;
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
splitk_batch_offset.b_k_split_offset / BPackedSize;
const auto a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
splitk_batch_offset.a_k_split_offset / APackedSize;
const auto b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
splitk_batch_offset.b_k_split_offset / BPackedSize;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
// allocate LDS
@@ -501,7 +475,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
b_flat_ptr,
kargs.ds_ptr,