[GEMM] Mark a/b block windows [[maybe_unused]] for builds where ENABLE_PREFETCH is defined

This commit is contained in:
YC Lin
2025-04-23 13:27:32 +00:00
parent 35de33c57b
commit ef085f402d
2 changed files with 25 additions and 33 deletions

View File

@@ -111,8 +111,8 @@ struct BlockGemmASmemBSmemCReg
// C += A * B
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockWindowTmp& a_block_window_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
[[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
@@ -127,14 +127,11 @@ struct BlockGemmASmemBSmemCReg
KPerBlock == BlockGemmShape::kK,
"wrong!");
// constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
// using WarpGemm = remove_cvref_t<decltype(config.template get<0>())>;
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
#if !defined(ENABLE_PREFETCH)
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
@@ -142,7 +139,7 @@ struct BlockGemmASmemBSmemCReg
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() % NWarp;
// construct A-warp-window
// Construct A-warp-window
auto a_warp_window_tmp = make_tile_window(
a_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
@@ -158,13 +155,12 @@ struct BlockGemmASmemBSmemCReg
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
// construct B-warp-window
// Construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
@@ -180,16 +176,16 @@ struct BlockGemmASmemBSmemCReg
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
// Read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
#if defined(ENABLE_PREFETCH)
#pragma message("local data share prefetch")
@@ -200,7 +196,7 @@ struct BlockGemmASmemBSmemCReg
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
#endif
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
// Read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
#if defined(ENABLE_PREFETCH)
b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data(
@@ -209,17 +205,17 @@ struct BlockGemmASmemBSmemCReg
#else
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
#endif
// read C warp tensor from C block tensor
// Read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
// Warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
// Write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
@@ -231,8 +227,8 @@ struct BlockGemmASmemBSmemCReg
// C = A * B
template <typename ABlockWindowTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockWindowTmp& a_block_window_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType>,
@@ -246,14 +242,11 @@ struct BlockGemmASmemBSmemCReg
KPerBlock == BlockGemmShape::kK,
"wrong!");
// constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
// using WarpGemm = remove_cvref_t<decltype(config.template get(number<0>{}))>;
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
#if !defined(ENABLE_PREFETCH)
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
@@ -261,7 +254,7 @@ struct BlockGemmASmemBSmemCReg
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() % NWarp;
// construct A-warp-window
// Construct A-warp-window
auto a_warp_window_tmp = make_tile_window(
a_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
@@ -277,13 +270,12 @@ struct BlockGemmASmemBSmemCReg
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
// construct B-warp-window
// Construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
@@ -299,11 +291,11 @@ struct BlockGemmASmemBSmemCReg
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
static_assert(std::is_same_v<CDataType, typename WarpGemm::CDataType>, "wrong!");
@@ -323,10 +315,10 @@ struct BlockGemmASmemBSmemCReg
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
// hot loop:
// Hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
// Read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
#if defined(ENABLE_PREFETCH)
a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data(
@@ -336,7 +328,7 @@ struct BlockGemmASmemBSmemCReg
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
#endif
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
// Read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
#if defined(ENABLE_PREFETCH)
b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data(
@@ -345,10 +337,10 @@ struct BlockGemmASmemBSmemCReg
#else
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
#endif
// read C warp tensor from C block tensor
// Read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// warp GEMM
// Warp GEMM
if constexpr(KIterPerWarp == 0)
{
// c = a * b
@@ -364,7 +356,7 @@ struct BlockGemmASmemBSmemCReg
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
// write C warp tensor into C block tensor
// Write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),

View File

@@ -307,7 +307,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
// Gemm pipeline start
#if defined(ENABLE_PREFETCH)
#pragma message("global prefetch")
// Initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
@@ -342,7 +342,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
// Main body
if(num_loop > 2)
{
index_t i = 0;
index_t iCounter = 0;
do
{
block_sync_lds();