Flatmm merge (#2168)

* sync with function interface of cshuffleepiloge,fix flatmm build fail

* move code from solin/flatmm which add mfma16*16*32fp8 and optimize flatmm

---------

Co-authored-by: solin <bingzhou@amd.com>
This commit is contained in:
BingYuan.Zhou
2025-05-08 12:59:57 +08:00
committed by GitHub
parent c7b8e86e34
commit 6a3960c1e1
11 changed files with 552 additions and 192 deletions

View File

@@ -66,76 +66,24 @@ struct BlockFlatmmASmemBSmemCRegV1
}
// C += A * B
template <typename CBlockTensor, typename ABlockWindow, typename BFlatBlockWindow>
template <typename CBlockTensor, typename ABlockWindow, typename BFlatBlockTensor>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockWindow& a_block_window,
const BFlatBlockWindow& b_flat_block_window) const
ABlockWindow& a_warp_windows,
BFlatBlockTensor& b_warp_tensor) const
{
static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> &&
std::is_same_v<BDataType, typename BFlatBlockWindow::DataType> &&
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"wrong!");
constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && KPerBlock == BlockGemmShape::kK, "wrong!");
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp =
BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN));
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
constexpr index_t NFlatPerBlockPerIter = BlockGemmShape::flatNPerWarp;
constexpr index_t KFlatPerBlockPerIter = BlockGemmShape::flatKPerWarp;
const index_t iMWarp = get_warp_id() / NWarp;
// construct A-warp-window
auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(),
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
// construct Bflat-warp-window
auto b_flat_warp_windows_tmp = b_flat_block_window;
statically_indexed_array<
statically_indexed_array<decltype(b_flat_warp_windows_tmp), KIterPerWarp>,
NIterPerWarp>
b_flat_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_warp_windows(nIter)(kIter) = b_flat_warp_windows_tmp;
move_tile_window(b_flat_warp_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
});
});
// auto b_warp_windows = b_origin_warp_windows;
auto b_warp_windows = b_flat_warp_windows;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
@@ -150,9 +98,6 @@ struct BlockFlatmmASmemBSmemCRegV1
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
@@ -161,7 +106,7 @@ struct BlockFlatmmASmemBSmemCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
@@ -172,16 +117,6 @@ struct BlockFlatmmASmemBSmemCRegV1
});
});
}
// C = A * B
template <typename ABlockTensorTmp, typename BFlatBlockWindow>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BFlatBlockWindow& b_flat_block_window) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor_tmp, b_flat_block_window);
return c_block_tensor;
}
};
} // namespace ck_tile