mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Optimize Flatmm MXFP4 by Eliminating Runtime Division by 2 (#3287)
* [CK_TILE] Optimize Flatmm MXFP4 by Eliminating Runtime Division by 2 * typo
This commit is contained in:
@@ -18,21 +18,21 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
{
|
||||
using Underlying = FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_>;
|
||||
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using FlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using MXFlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
|
||||
using BlockGemmShape =
|
||||
remove_cvref_t<typename MXFlatmmPipeline_::BlockGemmShape>; // TileFlatmmShape
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
|
||||
using ALayout = remove_cvref_t<typename MXFlatmmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename MXFlatmmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename MXFlatmmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
|
||||
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
|
||||
static constexpr index_t KernelBlockSize = MXFlatmmPipeline::BlockSize;
|
||||
static constexpr bool UsePersistentKernel = MXFlatmmPipeline::UsePersistentKernel;
|
||||
|
||||
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
|
||||
using ADataType = remove_cvref_t<typename MXFlatmmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename MXFlatmmPipeline::BDataType>;
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
@@ -43,9 +43,9 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
static constexpr int APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr int BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
static constexpr int MXdlPack = FlatmmPipeline::MXdlPack;
|
||||
static constexpr int NXdlPack = FlatmmPipeline::NXdlPack;
|
||||
static constexpr int KXdlPack = FlatmmPipeline::KXdlPack;
|
||||
static constexpr int MXdlPack = MXFlatmmPipeline::MXdlPack;
|
||||
static constexpr int NXdlPack = MXFlatmmPipeline::NXdlPack;
|
||||
static constexpr int KXdlPack = MXFlatmmPipeline::KXdlPack;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
@@ -63,7 +63,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
|
||||
return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, MXFlatmmPipeline::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -123,33 +123,23 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
|
||||
"A tensor for mx must be RowMajor");
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<MXFlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}();
|
||||
|
||||
constexpr index_t kKPerBlock = FlatmmPipeline::kKPerBlock;
|
||||
constexpr index_t kKPerBlock = MXFlatmmPipeline::kKPerBlock;
|
||||
constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
|
||||
const index_t kFlatKBlocks = kargs.K / kKPerBlock;
|
||||
const index_t kFlatN = kargs.N / kNWarpTile;
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0,
|
||||
static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0,
|
||||
"wrong! vector size for B tensor");
|
||||
auto&& naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(kFlatN, kFlatKBlocks, number<flatKPerBlock>{}));
|
||||
@@ -262,20 +252,12 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
|
||||
"A tensor for mx must be RowMajor");
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, MXFlatmmPipeline::kPadK>{});
|
||||
}();
|
||||
|
||||
const auto& b_flat_tensor_view = views.at(I1);
|
||||
@@ -289,14 +271,14 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadN>{});
|
||||
sequence<false, MXFlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
sequence<false, MXFlatmmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
@@ -309,14 +291,14 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadN>{});
|
||||
sequence<false, MXFlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<FlatmmPipeline::kPadM, false>{});
|
||||
sequence<MXFlatmmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -334,26 +316,18 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
const auto& e_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
|
||||
"A tensor for mx must be RowMajor");
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}();
|
||||
|
||||
const auto& b_flat_block_window =
|
||||
make_tile_window(b_flat_pad_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
make_tuple(number<MXFlatmmPipeline::flatNPerWarp>{},
|
||||
number<MXFlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
@@ -444,14 +418,14 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
|
||||
a_block_window.get_window_lengths(),
|
||||
a_block_window.get_window_origin(),
|
||||
FlatmmPipeline::GetADramTileDistribution());
|
||||
const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
|
||||
b_flat_block_window,
|
||||
scale_a_block_window,
|
||||
scale_b_block_window,
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
MXFlatmmPipeline::GetADramTileDistribution());
|
||||
const auto& c_block_tile = MXFlatmmPipeline{}(a_block_window_with_distr,
|
||||
b_flat_block_window,
|
||||
scale_a_block_window,
|
||||
scale_b_block_window,
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if constexpr(DoEpiScale)
|
||||
@@ -487,10 +461,10 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
|
||||
splitk_batch_offset.a_k_split_offset / APackedSize;
|
||||
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
|
||||
splitk_batch_offset.b_k_split_offset / BPackedSize;
|
||||
const auto a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
|
||||
splitk_batch_offset.a_k_split_offset / APackedSize;
|
||||
const auto b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
|
||||
splitk_batch_offset.b_k_split_offset / BPackedSize;
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
@@ -501,7 +475,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
|
||||
constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1);
|
||||
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
|
||||
b_flat_ptr,
|
||||
kargs.ds_ptr,
|
||||
|
||||
@@ -34,13 +34,11 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
|
||||
// using QuantType = BDataType_;
|
||||
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
|
||||
static constexpr int ScaleGranularityK = 32;
|
||||
|
||||
static constexpr int ContinuousKPerThread = 32; // it's fixed for fp4
|
||||
static constexpr int MXdlPack = 2; // it's fixed for fp4
|
||||
static constexpr int NXdlPack = 2; // it's fixed for fp4
|
||||
static constexpr int ContinuousKPerThread = 32; // it's fixed for mx
|
||||
static constexpr int MXdlPack = 2; // it's fixed for mx
|
||||
static constexpr int NXdlPack = 2; // it's fixed for mx
|
||||
static constexpr int KXdlPack = 2;
|
||||
// static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp * KXdlPack;
|
||||
static constexpr index_t flatKPerWarp = get_warp_size() * ContinuousKPerThread;
|
||||
@@ -63,6 +61,9 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
using BlockFlatmm =
|
||||
remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
|
||||
|
||||
@@ -81,8 +82,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t flatKPerWarp = Problem::flatKPerWarp;
|
||||
static constexpr index_t flatNPerWarp = Problem::flatNPerWarp;
|
||||
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return 32; } /* fixed for fp4 shuffle layout*/
|
||||
static constexpr index_t GetVectorSizeB() { return 32; } /* fixed for fp4 shuffle layout*/
|
||||
@@ -113,22 +114,22 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
|
||||
static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
|
||||
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
|
||||
static constexpr index_t KFlatBytesPerBlockPerIter = flatKPerWarp / BPackedSize;
|
||||
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
|
||||
|
||||
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
|
||||
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
|
||||
|
||||
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
// static constexpr index_t WG_AKPacks = WG::kK / APackedSize;
|
||||
// static constexpr index_t WG_BKPacks = WG::kK / BPackedSize;
|
||||
|
||||
static constexpr index_t MXdlPack = Problem::MXdlPack;
|
||||
static constexpr index_t NXdlPack = Problem::NXdlPack;
|
||||
static constexpr index_t KXdlPack = Problem::KXdlPack;
|
||||
static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK;
|
||||
|
||||
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
|
||||
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize;
|
||||
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType);
|
||||
|
||||
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
|
||||
? DsReadPreload
|
||||
@@ -562,11 +563,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
// B flat DRAM window for load
|
||||
|
||||
// pingpong buffer for B
|
||||
auto b_flat_dram_window =
|
||||
make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeMX_BFlatDramTileDistribution<Problem>());
|
||||
auto b_flat_dram_window = PipelinePolicy::template MakeMX_BFlatBytesDramWindow<Problem>(
|
||||
b_flat_dram_block_window_tmp);
|
||||
auto b_flat_dram_offsets = generate_tuple(
|
||||
[&](auto nIter) {
|
||||
constexpr auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
@@ -621,7 +619,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
scale_b_tile_tensor_ping, scale_b_tile_tensor_pong;
|
||||
|
||||
auto async_load_tile_ = [](auto lds, auto dram) {
|
||||
async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{});
|
||||
async_load_tile(lds, dram, number<-1>{}, true_type{}, true_type{});
|
||||
};
|
||||
|
||||
// HEAD
|
||||
@@ -633,11 +631,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_window, b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
|
||||
});
|
||||
// move B window to next flat K
|
||||
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
|
||||
});
|
||||
|
||||
// prefetch Scale A
|
||||
@@ -698,12 +697,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
|
||||
|
||||
// move B window to next flat K
|
||||
if constexpr(kIter == KIterPerWarp - 1)
|
||||
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -739,8 +738,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
|
||||
bit_cast<typename WG::AWarpTensor>(
|
||||
a_warp_tensor(number<AwarpIter>{})),
|
||||
bit_cast<typename WG::BWarpTensor>(
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0],
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
@@ -792,12 +793,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
|
||||
|
||||
// move B window to next flat K
|
||||
if constexpr(kIter == KIterPerWarp - 1)
|
||||
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -833,8 +834,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
|
||||
bit_cast<typename WG::AWarpTensor>(
|
||||
a_warp_tensor(number<AwarpIter>{})),
|
||||
bit_cast<typename WG::BWarpTensor>(
|
||||
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{})),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
|
||||
@@ -897,7 +900,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -932,8 +935,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
|
||||
bit_cast<typename WG::AWarpTensor>(
|
||||
a_warp_tensor(number<AwarpIter>{})),
|
||||
bit_cast<typename WG::BWarpTensor>(
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
@@ -986,8 +991,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
|
||||
bit_cast<typename WG::AWarpTensor>(
|
||||
a_warp_tensor(number<AwarpIter>{})),
|
||||
bit_cast<typename WG::BWarpTensor>(
|
||||
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{})),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
|
||||
@@ -1029,8 +1036,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
|
||||
bit_cast<typename WG::AWarpTensor>(
|
||||
a_warp_tensor(number<AwarpIter>{})),
|
||||
bit_cast<typename WG::BWarpTensor>(
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
|
||||
@@ -255,9 +255,11 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatDramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
constexpr index_t BPack = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");
|
||||
|
||||
@@ -282,21 +284,56 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
tile_distribution_encoding< //
|
||||
sequence<WaveRepeat>,
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
|
||||
sequence<K0, K1, K2>>, // 1 64 32
|
||||
sequence<K0, K1, K2 / BPack>>, // 1 64 32
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>,
|
||||
tuple<sequence<0, 0, 0>, sequence<1>>,
|
||||
sequence<2>,
|
||||
sequence<2>>,
|
||||
tile_distribution_encoding< //
|
||||
sequence<WaveRepeat>,
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
|
||||
sequence<num_access_v, K0, K1, K2>>, // 2 1 64 16
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
|
||||
sequence<num_access_v, K0, K1, K2 / BPack>>, // 2 1 64 16
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>,
|
||||
tuple<sequence<0, 0, 1>, sequence<2>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 3>>>{});
|
||||
}
|
||||
|
||||
// Builds a tile window over the flat B tensor that is addressed as raw uint8_t
// storage instead of as BDataType elements.
//
// The incoming 2-D window's bottom tensor view is re-described with every
// K-direction extent divided by BPackedSize
// (numeric_traits<BDataType>::PackedSize), its buffer is reinterpreted as
// const uint8_t*, and a new window is created with lengths and origin scaled
// the same way.
// NOTE(review): presumably this lets callers step the window in byte units so
// no per-access division by the packed size is needed - confirm against the
// pipeline call sites.
//
// Problem    - supplies BDataType and BlockGemmShape (kK, WarpTile,
//              flatNPerWarp, flatKPerWarp).
// window_tmp - 2-D tile window whose bottom tensor view and origin are reused.
// Returns a tile window over the byte-typed view, using
// MakeMX_BFlatBytesDramTileDistribution<Problem>() as its distribution.
template <typename Problem, typename WindowTmp>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
MakeMX_BFlatBytesDramWindow(const WindowTmp& window_tmp)
|
||||
{
|
||||
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
constexpr auto kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
// NOTE(review): WarpTile::at(I1) is called kNWarpTile where the kernel builds
// the flat B view; the "M_" prefix here looks like a naming slip - confirm.
constexpr auto M_Warp_Tile = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr auto flatNPerWarp = Problem::BlockGemmShape::flatNPerWarp;
|
||||
constexpr auto flatKPerWarp = Problem::BlockGemmShape::flatKPerWarp;
|
||||
|
||||
// Only a 2-D (flat N, flat K) window is accepted.
static_assert(std::decay_t<decltype(window_tmp)>::get_num_of_dimension() == 2);
|
||||
auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view();
|
||||
const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
|
||||
constexpr auto flat_k_per_block = kKPerBlock * M_Warp_Tile;
|
||||
// Packed 3-D descriptor (flat_n, flat_k / flat_k_per_block,
// flat_k_per_block / BPackedSize) whose last two dimensions are merged, giving
// a 2-D descriptor with a compile-time innermost extent.
auto&& byte_tensor_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(
|
||||
flat_n, flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{})),
|
||||
make_tuple(make_pass_through_transform(flat_n),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
// Address of element 0 of the original buffer view, reinterpreted as bytes.
auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
|
||||
auto&& byte_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(byte_ptr, byte_tensor_desc);
|
||||
auto&& origin_tmp = window_tmp.get_window_origin();
|
||||
// Window length and origin along flat K are divided by BPackedSize to match
// the byte-typed descriptor; the flat N coordinate is carried over unchanged.
return make_tile_window(
|
||||
byte_tensor_view,
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp / BPackedSize>{}),
|
||||
{origin_tmp[0], origin_tmp[1] / BPackedSize},
|
||||
MakeMX_BFlatBytesDramTileDistribution<Problem>());
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user