[CK_TILE] Optimize Flatmm MXFP4 by Eliminating Runtime Division by 2 (#3287)

* [CK_TILE] Optimize Flatmm MXFP4 by Eliminating Runtime Division by 2

* Fix typo
Yi DING
2025-12-08 19:20:44 +08:00
committed by GitHub
parent 04612c30ce
commit 878b4e7f46
3 changed files with 141 additions and 121 deletions
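
For context (my reading of the diff, not an official summary): MXFP4 packs two 4-bit elements per byte, so numeric_traits<BDataType>::PackedSize == 2, and the previous pipeline divided element offsets by that factor while forming B addresses inside the main loop. This commit instead builds a byte-addressed window over the flat B tensor once, so the per-iteration step becomes a compile-time byte count and the hot path only adds offsets. A minimal standalone C++ sketch of the idea; the names addr_by_element/addr_by_byte and the buffer size are illustrative, and only flatKPerWarp = 64 * 32 matches constants visible in the diff:

#include <cstdint>
#include <cstdio>

constexpr int BPackedSize = 2;                     // two fp4 elements per byte

// Before: B offsets tracked in elements; every address computation divides.
const uint8_t* addr_by_element(const uint8_t* base, int elem_offset)
{
    return base + elem_offset / BPackedSize;       // division on the hot path
}

// After: the window is byte-addressed, and the per-iteration step is a
// compile-time byte constant, so the hot path is a plain add.
constexpr int flatKPerWarp              = 64 * 32;                      // elements
constexpr int KFlatBytesPerBlockPerIter = flatKPerWarp / BPackedSize;   // folded at compile time

const uint8_t* addr_by_byte(const uint8_t* base, int byte_offset)
{
    return base + byte_offset;
}

int main()
{
    static uint8_t buf[4096] = {};
    // Both forms reach the same address; only the second avoids the division.
    std::printf("%d\n", addr_by_element(buf, 2 * flatKPerWarp) ==
                        addr_by_byte(buf, 2 * KFlatBytesPerBlockPerIter));
}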


@@ -18,21 +18,21 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
{
using Underlying = FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using FlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using MXFlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
using BlockGemmShape =
remove_cvref_t<typename MXFlatmmPipeline_::BlockGemmShape>; // TileFlatmmShape
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
using ALayout = remove_cvref_t<typename MXFlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename MXFlatmmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename MXFlatmmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
static constexpr index_t KernelBlockSize = MXFlatmmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = MXFlatmmPipeline::UsePersistentKernel;
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
using ADataType = remove_cvref_t<typename MXFlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename MXFlatmmPipeline::BDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
@@ -43,9 +43,9 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
static constexpr int APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr int BPackedSize = numeric_traits<BDataType>::PackedSize;
static constexpr int MXdlPack = FlatmmPipeline::MXdlPack;
static constexpr int NXdlPack = FlatmmPipeline::NXdlPack;
static constexpr int KXdlPack = FlatmmPipeline::KXdlPack;
static constexpr int MXdlPack = MXFlatmmPipeline::MXdlPack;
static constexpr int NXdlPack = MXFlatmmPipeline::NXdlPack;
static constexpr int KXdlPack = MXFlatmmPipeline::KXdlPack;
static constexpr index_t NumDTensor = DsDataType::size();
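
A note on the packed-size constants above: for fp4 data, numeric_traits reports PackedSize == 2, meaning one storage unit carries two elements, which is exactly the divisor this commit moves out of the runtime path. A hypothetical trait sketch (numeric_traits_sketch and fp4x2_t are stand-in names, not ck_tile's):

#include <cstdint>

struct fp4x2_t { uint8_t bits; };                  // two 4-bit values per storage unit

template <typename T> struct numeric_traits_sketch;
template <> struct numeric_traits_sketch<float>   { static constexpr int PackedSize = 1; };
template <> struct numeric_traits_sketch<fp4x2_t> { static constexpr int PackedSize = 2; };

// Converting an element count to a storage count divides by PackedSize:
static_assert(32 / numeric_traits_sketch<fp4x2_t>::PackedSize == 16,
              "32 fp4 elements occupy 16 storage bytes");

int main() {}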
@@ -63,7 +63,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, MXFlatmmPipeline::GetName());
// clang-format on
}
@@ -123,33 +123,23 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
const SplitKBatchOffset& splitk_batch_offset)
{
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
"A tensor for mx must be RowMajor");
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
number<MXFlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}();
constexpr index_t kKPerBlock = FlatmmPipeline::kKPerBlock;
constexpr index_t kKPerBlock = MXFlatmmPipeline::kKPerBlock;
constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
const index_t kFlatKBlocks = kargs.K / kKPerBlock;
const index_t kFlatN = kargs.N / kNWarpTile;
const auto& b_flat_tensor_view = [&]() {
static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0,
static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0,
"wrong! vector size for B tensor");
auto&& naive_desc = make_naive_tensor_descriptor_packed(
make_tuple(kFlatN, kFlatKBlocks, number<flatKPerBlock>{}));
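
The removed else branch was dead on the mx path, which only supports RowMajor A; replacing it with a static_assert turns that assumption into a compile-time error instead of a silently mis-strided view. A toy sketch of the simplification pattern (make_a_view_sketch is a stand-in, not the kernel's function):

#include <type_traits>

struct RowMajor {};
struct ColMajor {};

template <typename ALayout>
int make_a_view_sketch()
{
    // One supported path instead of an if-constexpr with a dead else:
    static_assert(std::is_same_v<ALayout, RowMajor>, "A tensor for mx must be RowMajor");
    return 0;
}

int main() { return make_a_view_sketch<RowMajor>(); }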
@@ -262,20 +252,12 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
"A tensor for mx must be RowMajor");
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, MXFlatmmPipeline::kPadK>{});
}();
const auto& b_flat_tensor_view = views.at(I1);
@@ -289,14 +271,14 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
sequence<false, MXFlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
sequence<false, MXFlatmmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
@@ -309,14 +291,14 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
sequence<false, MXFlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<FlatmmPipeline::kPadM, false>{});
sequence<MXFlatmmPipeline::kPadM, false>{});
}
}();
@@ -334,26 +316,18 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
const auto& e_pad_view = views.at(I3);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
"A tensor for mx must be RowMajor");
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}();
const auto& b_flat_block_window =
make_tile_window(b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
make_tuple(number<MXFlatmmPipeline::flatNPerWarp>{},
number<MXFlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
const auto ds_block_window = generate_tuple(
@@ -444,14 +418,14 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
a_block_window.get_window_lengths(),
a_block_window.get_window_origin(),
FlatmmPipeline::GetADramTileDistribution());
const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
b_flat_block_window,
scale_a_block_window,
scale_b_block_window,
num_loop,
smem_ptr_ping,
smem_ptr_pong);
MXFlatmmPipeline::GetADramTileDistribution());
const auto& c_block_tile = MXFlatmmPipeline{}(a_block_window_with_distr,
b_flat_block_window,
scale_a_block_window,
scale_b_block_window,
num_loop,
smem_ptr_ping,
smem_ptr_pong);
// Run Epilogue Pipeline
if constexpr(DoEpiScale)
@@ -487,10 +461,10 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
const SplitKBatchOffset splitk_batch_offset(kargs);
// options
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
splitk_batch_offset.a_k_split_offset / APackedSize;
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
splitk_batch_offset.b_k_split_offset / BPackedSize;
const auto a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
splitk_batch_offset.a_k_split_offset / APackedSize;
const auto b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
splitk_batch_offset.b_k_split_offset / BPackedSize;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
// allocate LDS
@@ -501,7 +475,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
b_flat_ptr,
kargs.ds_ptr,
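
The divisions by APackedSize/BPackedSize kept here are fine to retain: they convert the split-K element offset into a packed-storage pointer bump once per kernel launch, not once per loop iteration. A sketch under assumed values (a_storage and the 8192-element offset are illustrative):

#include <cstdint>

int main()
{
    constexpr int APackedSize = 2;                 // fp4: two elements per storage byte
    static uint8_t a_storage[8192] = {};
    const int a_k_split_offset = 8192;             // split-K offset, in elements
    // One division per launch, outside the main loop:
    const uint8_t* a_ptr = a_storage + a_k_split_offset / APackedSize;
    return a_ptr == a_storage + 4096 ? 0 : 1;
}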


@@ -34,13 +34,11 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
// using QuantType = BDataType_;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
static constexpr int ScaleGranularityK = 32;
static constexpr int ContinuousKPerThread = 32; // it's fixed for fp4
static constexpr int MXdlPack = 2; // it's fixed for fp4
static constexpr int NXdlPack = 2; // it's fixed for fp4
static constexpr int ContinuousKPerThread = 32; // it's fixed for mx
static constexpr int MXdlPack = 2; // it's fixed for mx
static constexpr int NXdlPack = 2; // it's fixed for mx
static constexpr int KXdlPack = 2;
// static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp * KXdlPack;
static constexpr index_t flatKPerWarp = get_warp_size() * ContinuousKPerThread;
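
A quick consistency check of the sizing above, assuming a 64-lane wavefront: flatKPerWarp = get_warp_size() * ContinuousKPerThread gives 2048 fp4 elements, i.e. 1 KiB of packed bytes per warp per flat-K step. As a compile-checked sketch:

constexpr int warp_size            = 64;           // gfx9 wavefront, assumed
constexpr int ContinuousKPerThread = 32;           // fixed for mx (from above)
constexpr int flatKPerWarp         = warp_size * ContinuousKPerThread; // 2048 elements
constexpr int BPackedSize          = 2;            // fp4
static_assert(flatKPerWarp / BPackedSize == 1024,
              "each warp streams 1 KiB of packed B per flat-K step");
int main() {}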
@@ -63,6 +61,9 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
using BlockFlatmm =
remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
@@ -81,8 +82,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr index_t flatKPerWarp = Problem::flatKPerWarp;
static constexpr index_t flatNPerWarp = Problem::flatNPerWarp;
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
static constexpr index_t GetVectorSizeA() { return 32; } /* fixed for fp4 shuffle layout*/
static constexpr index_t GetVectorSizeB() { return 32; } /* fixed for fp4 shuffle layout*/
@@ -113,22 +114,22 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
static constexpr index_t KFlatBytesPerBlockPerIter = flatKPerWarp / BPackedSize;
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
// static constexpr index_t WG_AKPacks = WG::kK / APackedSize;
// static constexpr index_t WG_BKPacks = WG::kK / BPackedSize;
static constexpr index_t MXdlPack = Problem::MXdlPack;
static constexpr index_t NXdlPack = Problem::NXdlPack;
static constexpr index_t KXdlPack = Problem::KXdlPack;
static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK;
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize;
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType);
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType);
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
? DsReadPreload
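
The renamed constant makes the unit change explicit: KFlatBytesPerBlockPerIter = flatKPerWarp / BPackedSize folds the division by 2 into a compile-time byte count, and AK1/BK1 drop their PackedSize factor because the vector width is now counted in storage units rather than unpacked elements. A sketch with an assumed VectorLoadSize of 16 bytes and 1-byte packed fp4 storage:

constexpr int VectorLoadSize = 16;                 // bytes per vector load, assumed
constexpr int PackedSize     = 2;                  // fp4 elements per 1-byte storage unit
constexpr int sizeof_storage = 1;

constexpr int K1_old = VectorLoadSize / sizeof_storage * PackedSize; // 32, in elements
constexpr int K1_new = VectorLoadSize / sizeof_storage;              // 16, in storage units
static_assert(K1_old == K1_new * PackedSize, "same bytes, element vs storage units");
int main() {}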
@@ -562,11 +563,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
// B flat DRAM window for load
// pingpong buffer for B
auto b_flat_dram_window =
make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeMX_BFlatDramTileDistribution<Problem>());
auto b_flat_dram_window = PipelinePolicy::template MakeMX_BFlatBytesDramWindow<Problem>(
b_flat_dram_block_window_tmp);
auto b_flat_dram_offsets = generate_tuple(
[&](auto nIter) {
constexpr auto packed_n_idx = nIter / number<NXdlPack>{};
@@ -621,7 +619,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
scale_b_tile_tensor_ping, scale_b_tile_tensor_pong;
auto async_load_tile_ = [](auto lds, auto dram) {
async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{});
async_load_tile(lds, dram, number<-1>{}, true_type{}, true_type{});
};
// HEAD
@@ -633,11 +631,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_window, b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
});
// move B window to next flat K
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
});
// prefetch Scale A
@@ -698,12 +697,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
// move B window to next flat K
if constexpr(kIter == KIterPerWarp - 1)
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
});
});
@@ -739,8 +738,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
bit_cast<typename WG::AWarpTensor>(
a_warp_tensor(number<AwarpIter>{})),
bit_cast<typename WG::BWarpTensor>(
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0],
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
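
On the bit_cast introduced above: tiles fetched through the byte-addressed window carry raw byte elements, while the warp GEMM expects its own register-tensor types, so a size-preserving bit_cast reinterprets the registers without copying or changing the loaded bits. A standalone sketch using std::bit_cast, with BWarpTensorSketch as a stand-in for WG::BWarpTensor:

#include <array>
#include <bit>
#include <cstdint>

// Stand-in for the warp GEMM's register tensor; only the size matters here.
struct BWarpTensorSketch { std::array<uint8_t, 16> payload; };

int main()
{
    std::array<uint8_t, 16> raw{};                  // what the byte window yields
    auto b = std::bit_cast<BWarpTensorSketch>(raw); // size-preserving reinterpretation
    return b.payload[0];
}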
@@ -792,12 +793,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
// move B window to next flat K
if constexpr(kIter == KIterPerWarp - 1)
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
});
});
@@ -833,8 +834,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
bit_cast<typename WG::AWarpTensor>(
a_warp_tensor(number<AwarpIter>{})),
bit_cast<typename WG::BWarpTensor>(
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{})),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
@@ -897,7 +900,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
});
});
@@ -932,8 +935,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
bit_cast<typename WG::AWarpTensor>(
a_warp_tensor(number<AwarpIter>{})),
bit_cast<typename WG::BWarpTensor>(
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
@@ -986,8 +991,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
bit_cast<typename WG::AWarpTensor>(
a_warp_tensor(number<AwarpIter>{})),
bit_cast<typename WG::BWarpTensor>(
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{})),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
@@ -1029,8 +1036,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
bit_cast<typename WG::AWarpTensor>(
a_warp_tensor(number<AwarpIter>{})),
bit_cast<typename WG::BWarpTensor>(
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)


@@ -255,9 +255,11 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatDramTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
using TileShape = typename Problem::BlockGemmShape;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t BPack = numeric_traits<BDataType>::PackedSize;
static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");
@@ -282,21 +284,56 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
sequence<K0, K1, K2>>, // 1 64 32
sequence<K0, K1, K2 / BPack>>, // 1 64 32
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 0>, sequence<1>>,
sequence<2>,
sequence<2>>,
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
sequence<num_access_v, K0, K1, K2>>, // 2 1 64 16
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
sequence<num_access_v, K0, K1, K2 / BPack>>, // 2 1 64 16
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 1>, sequence<2>>,
sequence<2, 2>,
sequence<0, 3>>>{});
}
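
The K2 / BPack factor scales the innermost distribution length from elements to bytes; with the commented shape K0=1, K1=64, K2=32 and BPack=2 for fp4, each access spans 16 bytes. As a one-line sanity check:

constexpr int K2    = 32;                          // contiguous fp4 elements per access
constexpr int BPack = 2;
static_assert(K2 / BPack == 16, "each access covers 16 bytes of packed B");
int main() {}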
template <typename Problem, typename WindowTmp>
CK_TILE_HOST_DEVICE static constexpr auto
MakeMX_BFlatBytesDramWindow(const WindowTmp& window_tmp)
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
constexpr auto kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto M_Warp_Tile = Problem::BlockGemmShape::WarpTile::at(I1);
constexpr auto flatNPerWarp = Problem::BlockGemmShape::flatNPerWarp;
constexpr auto flatKPerWarp = Problem::BlockGemmShape::flatKPerWarp;
static_assert(std::decay_t<decltype(window_tmp)>::get_num_of_dimension() == 2);
auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view();
const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
constexpr auto flat_k_per_block = kKPerBlock * M_Warp_Tile;
auto&& byte_tensor_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(
flat_n, flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{})),
make_tuple(make_pass_through_transform(flat_n),
make_merge_transform_v3_division_mod(make_tuple(
flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
auto&& byte_tensor_view =
make_tensor_view<address_space_enum::global>(byte_ptr, byte_tensor_desc);
auto&& origin_tmp = window_tmp.get_window_origin();
return make_tile_window(
byte_tensor_view,
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp / BPackedSize>{}),
{origin_tmp[0], origin_tmp[1] / BPackedSize},
MakeMX_BFlatBytesDramTileDistribution<Problem>());
}
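
Reading the helper above: it reinterprets the B buffer as uint8_t, rebuilds the descriptor with an innermost length of flat_k_per_block / BPackedSize, merges that back into a 2-D (flat_n, flat_k_bytes) view, and divides the window origin by BPackedSize exactly once. A standalone sketch of the equivalent index math, with an example flat_k_per_block (the real value is kKPerBlock * M_Warp_Tile in the code above):

#include <cstdio>

int main()
{
    const int flat_k_per_block = 4096;             // example value only
    const int BPackedSize      = 2;
    const int block_bytes      = flat_k_per_block / BPackedSize;

    // element-domain flat-k -> (block, byte-in-block) -> merged byte coordinate
    const int flat_k_elem = 10000;
    const int block       = flat_k_elem / flat_k_per_block;
    const int byte_in_blk = flat_k_elem % flat_k_per_block / BPackedSize;
    const int flat_k_byte = block * block_bytes + byte_in_blk;   // == flat_k_elem / 2
    std::printf("%d\n", flat_k_byte);
}

Since every offset produced by this window is already a byte count, the per-iteration get_load_offset additions in the pipeline need no further division.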
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
{