Prefetch K for next iteration from LDS in block_gemm_areg_bsmem_creg for gemm-0

This commit is contained in:
Qianfeng Zhang
2025-05-18 13:40:38 +00:00
parent 7c0ac51b4b
commit afd7793e92

View File

@@ -135,11 +135,17 @@ struct BlockGemmARegBSmemCRegV2Hack_0
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
// hot loop:
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor_0 = load_tile(b_warp_windows(nIter)(I0));
using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0)));
statically_indexed_array<b_warp_tensor_type, KIterPerWarp> b_warp_tensors;
b_warp_tensors[I0] = load_tile(b_warp_windows(nIter)(I0));
b_warp_tensors[I1] = load_tile(b_warp_windows(nIter)(I1));
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
@@ -150,7 +156,7 @@ struct BlockGemmARegBSmemCRegV2Hack_0
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// warp GEMM
auto c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensor_0);
auto c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[I0]);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
// write C warp tensor into C block tensor
@@ -162,7 +168,9 @@ struct BlockGemmARegBSmemCRegV2Hack_0
static_for<1, KIterPerWarp, 1>{}([&](auto kIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
if constexpr(kIter < KIterPerWarp - 1)
b_warp_tensors[number<kIter + 1>{}] =
load_tile(b_warp_windows(nIter)(number<kIter + 1>{}));
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
@@ -180,7 +188,7 @@ struct BlockGemmARegBSmemCRegV2Hack_0
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
// write C warp tensor into C block tensor