[CK_TILE] Optimize Flatmm MXFP4 by Eliminating Runtime Division by 2 (#3287)

* [CK_TILE] Optimize Flatmm MXFP4 by Eliminating Runtime Division by 2

* typo
This commit is contained in:
Yi DING
2025-12-08 19:20:44 +08:00
committed by GitHub
parent 04612c30ce
commit 878b4e7f46
3 changed files with 141 additions and 121 deletions

View File

@@ -34,13 +34,11 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
// using QuantType = BDataType_;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
static constexpr int ScaleGranularityK = 32;
static constexpr int ContinuousKPerThread = 32; // it's fixed for fp4
static constexpr int MXdlPack = 2; // it's fixed for fp4
static constexpr int NXdlPack = 2; // it's fixed for fp4
static constexpr int ContinuousKPerThread = 32; // it's fixed for mx
static constexpr int MXdlPack = 2; // it's fixed for mx
static constexpr int NXdlPack = 2; // it's fixed for mx
static constexpr int KXdlPack = 2;
// static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp * KXdlPack;
static constexpr index_t flatKPerWarp = get_warp_size() * ContinuousKPerThread;
@@ -63,6 +61,9 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
using BlockFlatmm =
remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
@@ -81,8 +82,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr index_t flatKPerWarp = Problem::flatKPerWarp;
static constexpr index_t flatNPerWarp = Problem::flatNPerWarp;
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
static constexpr index_t GetVectorSizeA() { return 32; } /* fixed for fp4 shuffle layout*/
static constexpr index_t GetVectorSizeB() { return 32; } /* fixed for fp4 shuffle layout*/
@@ -113,22 +114,22 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
static constexpr index_t KFlatBytesPerBlockPerIter = flatKPerWarp / BPackedSize;
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
// static constexpr index_t WG_AKPacks = WG::kK / APackedSize;
// static constexpr index_t WG_BKPacks = WG::kK / BPackedSize;
static constexpr index_t MXdlPack = Problem::MXdlPack;
static constexpr index_t NXdlPack = Problem::NXdlPack;
static constexpr index_t KXdlPack = Problem::KXdlPack;
static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK;
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize;
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType);
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType);
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
? DsReadPreload
@@ -562,11 +563,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
// B flat DRAM window for load
// pingpong buffer for B
auto b_flat_dram_window =
make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeMX_BFlatDramTileDistribution<Problem>());
auto b_flat_dram_window = PipelinePolicy::template MakeMX_BFlatBytesDramWindow<Problem>(
b_flat_dram_block_window_tmp);
auto b_flat_dram_offsets = generate_tuple(
[&](auto nIter) {
constexpr auto packed_n_idx = nIter / number<NXdlPack>{};
@@ -621,7 +619,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
scale_b_tile_tensor_ping, scale_b_tile_tensor_pong;
auto async_load_tile_ = [](auto lds, auto dram) {
async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{});
async_load_tile(lds, dram, number<-1>{}, true_type{}, true_type{});
};
// HEAD
@@ -633,11 +631,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_window, b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
});
// move B window to next flat K
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
});
// prefetch Scale A
@@ -698,12 +697,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
// move B window to next flat K
if constexpr(kIter == KIterPerWarp - 1)
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
});
});
@@ -739,8 +738,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
bit_cast<typename WG::AWarpTensor>(
a_warp_tensor(number<AwarpIter>{})),
bit_cast<typename WG::BWarpTensor>(
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0],
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
@@ -792,12 +793,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
// move B window to next flat K
if constexpr(kIter == KIterPerWarp - 1)
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
});
});
@@ -833,8 +834,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
bit_cast<typename WG::AWarpTensor>(
a_warp_tensor(number<AwarpIter>{})),
bit_cast<typename WG::BWarpTensor>(
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{})),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
@@ -897,7 +900,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
});
});
@@ -932,8 +935,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
bit_cast<typename WG::AWarpTensor>(
a_warp_tensor(number<AwarpIter>{})),
bit_cast<typename WG::BWarpTensor>(
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
@@ -986,8 +991,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
bit_cast<typename WG::AWarpTensor>(
a_warp_tensor(number<AwarpIter>{})),
bit_cast<typename WG::BWarpTensor>(
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{})),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
@@ -1029,8 +1036,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
bit_cast<typename WG::AWarpTensor>(
a_warp_tensor(number<AwarpIter>{})),
bit_cast<typename WG::BWarpTensor>(
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)

View File

@@ -255,9 +255,11 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatDramTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
using TileShape = typename Problem::BlockGemmShape;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t BPack = numeric_traits<BDataType>::PackedSize;
static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");
@@ -282,21 +284,56 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
sequence<K0, K1, K2>>, // 1 64 32
sequence<K0, K1, K2 / BPack>>, // 1 64 32
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 0>, sequence<1>>,
sequence<2>,
sequence<2>>,
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
sequence<num_access_v, K0, K1, K2>>, // 2 1 64 16
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
sequence<num_access_v, K0, K1, K2 / BPack>>, // 2 1 64 16
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 1>, sequence<2>>,
sequence<2, 2>,
sequence<0, 3>>>{});
}
template <typename Problem, typename WindowTmp>
CK_TILE_HOST_DEVICE static constexpr auto
MakeMX_BFlatBytesDramWindow(const WindowTmp& window_tmp)
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
constexpr auto kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto M_Warp_Tile = Problem::BlockGemmShape::WarpTile::at(I1);
constexpr auto flatNPerWarp = Problem::BlockGemmShape::flatNPerWarp;
constexpr auto flatKPerWarp = Problem::BlockGemmShape::flatKPerWarp;
static_assert(std::decay_t<decltype(window_tmp)>::get_num_of_dimension() == 2);
auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view();
const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
constexpr auto flat_k_per_block = kKPerBlock * M_Warp_Tile;
auto&& byte_tensor_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(
flat_n, flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{})),
make_tuple(make_pass_through_transform(flat_n),
make_merge_transform_v3_division_mod(make_tuple(
flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
auto&& byte_tensor_view =
make_tensor_view<address_space_enum::global>(byte_ptr, byte_tensor_desc);
auto&& origin_tmp = window_tmp.get_window_origin();
return make_tile_window(
byte_tensor_view,
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp / BPackedSize>{}),
{origin_tmp[0], origin_tmp[1] / BPackedSize},
MakeMX_BFlatBytesDramTileDistribution<Problem>());
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
{