mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
[CK_TILE] Optimize Flatmm MXFP4 by Eliminating Runtime Division by 2 (#3287)
* [CK_TILE] Optimize Flatmm MXFP4 by Eliminating Runtime Division by 2 * typo
This commit is contained in:
@@ -34,13 +34,11 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
|
||||
// using QuantType = BDataType_;
|
||||
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
|
||||
static constexpr int ScaleGranularityK = 32;
|
||||
|
||||
static constexpr int ContinuousKPerThread = 32; // it's fixed for fp4
|
||||
static constexpr int MXdlPack = 2; // it's fixed for fp4
|
||||
static constexpr int NXdlPack = 2; // it's fixed for fp4
|
||||
static constexpr int ContinuousKPerThread = 32; // it's fixed for mx
|
||||
static constexpr int MXdlPack = 2; // it's fixed for mx
|
||||
static constexpr int NXdlPack = 2; // it's fixed for mx
|
||||
static constexpr int KXdlPack = 2;
|
||||
// static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp * KXdlPack;
|
||||
static constexpr index_t flatKPerWarp = get_warp_size() * ContinuousKPerThread;
|
||||
@@ -63,6 +61,9 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
using BlockFlatmm =
|
||||
remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
|
||||
|
||||
@@ -81,8 +82,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t flatKPerWarp = Problem::flatKPerWarp;
|
||||
static constexpr index_t flatNPerWarp = Problem::flatNPerWarp;
|
||||
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return 32; } /* fixed for fp4 shuffle layout*/
|
||||
static constexpr index_t GetVectorSizeB() { return 32; } /* fixed for fp4 shuffle layout*/
|
||||
@@ -113,22 +114,22 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
|
||||
static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
|
||||
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
|
||||
static constexpr index_t KFlatBytesPerBlockPerIter = flatKPerWarp / BPackedSize;
|
||||
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
|
||||
|
||||
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
|
||||
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
|
||||
|
||||
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
// static constexpr index_t WG_AKPacks = WG::kK / APackedSize;
|
||||
// static constexpr index_t WG_BKPacks = WG::kK / BPackedSize;
|
||||
|
||||
static constexpr index_t MXdlPack = Problem::MXdlPack;
|
||||
static constexpr index_t NXdlPack = Problem::NXdlPack;
|
||||
static constexpr index_t KXdlPack = Problem::KXdlPack;
|
||||
static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK;
|
||||
|
||||
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
|
||||
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize;
|
||||
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType);
|
||||
|
||||
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
|
||||
? DsReadPreload
|
||||
@@ -562,11 +563,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
// B flat DRAM window for load
|
||||
|
||||
// pingpong buffer for B
|
||||
auto b_flat_dram_window =
|
||||
make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeMX_BFlatDramTileDistribution<Problem>());
|
||||
auto b_flat_dram_window = PipelinePolicy::template MakeMX_BFlatBytesDramWindow<Problem>(
|
||||
b_flat_dram_block_window_tmp);
|
||||
auto b_flat_dram_offsets = generate_tuple(
|
||||
[&](auto nIter) {
|
||||
constexpr auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
@@ -621,7 +619,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
scale_b_tile_tensor_ping, scale_b_tile_tensor_pong;
|
||||
|
||||
auto async_load_tile_ = [](auto lds, auto dram) {
|
||||
async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{});
|
||||
async_load_tile(lds, dram, number<-1>{}, true_type{}, true_type{});
|
||||
};
|
||||
|
||||
// HEAD
|
||||
@@ -633,11 +631,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_window, b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
|
||||
});
|
||||
// move B window to next flat K
|
||||
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
|
||||
});
|
||||
|
||||
// prefetch Scale A
|
||||
@@ -698,12 +697,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
|
||||
|
||||
// move B window to next flat K
|
||||
if constexpr(kIter == KIterPerWarp - 1)
|
||||
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -739,8 +738,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
|
||||
bit_cast<typename WG::AWarpTensor>(
|
||||
a_warp_tensor(number<AwarpIter>{})),
|
||||
bit_cast<typename WG::BWarpTensor>(
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0],
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
@@ -792,12 +793,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
|
||||
|
||||
// move B window to next flat K
|
||||
if constexpr(kIter == KIterPerWarp - 1)
|
||||
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -833,8 +834,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
|
||||
bit_cast<typename WG::AWarpTensor>(
|
||||
a_warp_tensor(number<AwarpIter>{})),
|
||||
bit_cast<typename WG::BWarpTensor>(
|
||||
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{})),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
|
||||
@@ -897,7 +900,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -932,8 +935,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
|
||||
bit_cast<typename WG::AWarpTensor>(
|
||||
a_warp_tensor(number<AwarpIter>{})),
|
||||
bit_cast<typename WG::BWarpTensor>(
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
@@ -986,8 +991,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
|
||||
bit_cast<typename WG::AWarpTensor>(
|
||||
a_warp_tensor(number<AwarpIter>{})),
|
||||
bit_cast<typename WG::BWarpTensor>(
|
||||
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{})),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
|
||||
@@ -1029,8 +1036,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
|
||||
bit_cast<typename WG::AWarpTensor>(
|
||||
a_warp_tensor(number<AwarpIter>{})),
|
||||
bit_cast<typename WG::BWarpTensor>(
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
|
||||
@@ -255,9 +255,11 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatDramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
constexpr index_t BPack = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");
|
||||
|
||||
@@ -282,21 +284,56 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
tile_distribution_encoding< //
|
||||
sequence<WaveRepeat>,
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
|
||||
sequence<K0, K1, K2>>, // 1 64 32
|
||||
sequence<K0, K1, K2 / BPack>>, // 1 64 32
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>,
|
||||
tuple<sequence<0, 0, 0>, sequence<1>>,
|
||||
sequence<2>,
|
||||
sequence<2>>,
|
||||
tile_distribution_encoding< //
|
||||
sequence<WaveRepeat>,
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
|
||||
sequence<num_access_v, K0, K1, K2>>, // 2 1 64 16
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
|
||||
sequence<num_access_v, K0, K1, K2 / BPack>>, // 2 1 64 16
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>,
|
||||
tuple<sequence<0, 0, 1>, sequence<2>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 3>>>{});
|
||||
}
|
||||
|
||||
template <typename Problem, typename WindowTmp>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
MakeMX_BFlatBytesDramWindow(const WindowTmp& window_tmp)
|
||||
{
|
||||
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
constexpr auto kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr auto M_Warp_Tile = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr auto flatNPerWarp = Problem::BlockGemmShape::flatNPerWarp;
|
||||
constexpr auto flatKPerWarp = Problem::BlockGemmShape::flatKPerWarp;
|
||||
|
||||
static_assert(std::decay_t<decltype(window_tmp)>::get_num_of_dimension() == 2);
|
||||
auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view();
|
||||
const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
|
||||
constexpr auto flat_k_per_block = kKPerBlock * M_Warp_Tile;
|
||||
auto&& byte_tensor_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(
|
||||
flat_n, flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{})),
|
||||
make_tuple(make_pass_through_transform(flat_n),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
|
||||
auto&& byte_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(byte_ptr, byte_tensor_desc);
|
||||
auto&& origin_tmp = window_tmp.get_window_origin();
|
||||
return make_tile_window(
|
||||
byte_tensor_view,
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp / BPackedSize>{}),
|
||||
{origin_tmp[0], origin_tmp[1] / BPackedSize},
|
||||
MakeMX_BFlatBytesDramTileDistribution<Problem>());
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user