mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Prefetch K for next iteration from LDS in block_gemm_areg_bsmem_creg for gemm-0
This commit is contained in:
@@ -135,11 +135,17 @@ struct BlockGemmARegBSmemCRegV2Hack_0
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
const auto b_warp_tensor_0 = load_tile(b_warp_windows(nIter)(I0));
|
||||
using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0)));
|
||||
|
||||
statically_indexed_array<b_warp_tensor_type, KIterPerWarp> b_warp_tensors;
|
||||
|
||||
b_warp_tensors[I0] = load_tile(b_warp_windows(nIter)(I0));
|
||||
|
||||
b_warp_tensors[I1] = load_tile(b_warp_windows(nIter)(I1));
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
@@ -150,7 +156,7 @@ struct BlockGemmARegBSmemCRegV2Hack_0
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
auto c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensor_0);
|
||||
auto c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[I0]);
|
||||
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
@@ -162,7 +168,9 @@ struct BlockGemmARegBSmemCRegV2Hack_0
|
||||
|
||||
static_for<1, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
// read B warp tensor from B Block window
|
||||
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
if constexpr(kIter < KIterPerWarp - 1)
|
||||
b_warp_tensors[number<kIter + 1>{}] =
|
||||
load_tile(b_warp_windows(nIter)(number<kIter + 1>{}));
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
@@ -180,7 +188,7 @@ struct BlockGemmARegBSmemCRegV2Hack_0
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]);
|
||||
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
|
||||
Reference in New Issue
Block a user