Simplify the block_gemm codes

This commit is contained in:
Qianfeng Zhang
2025-12-06 14:33:01 +00:00
parent 25521a7e06
commit 8b85919288
2 changed files with 45 additions and 65 deletions

View File

@@ -128,7 +128,6 @@ struct BlockGemmARegBSmemCRegV2PrefetchK
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
// hot loop:
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
@@ -136,6 +135,7 @@ struct BlockGemmARegBSmemCRegV2PrefetchK
statically_indexed_array<b_warp_tensor_type, KIterPerWarp> b_warp_tensors;
// read B warp tensor from B Block window
b_warp_windows(nIter)(I0) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(I0),
{nIter * NPerBlockPerIter, 0 * KPerBlockPerIter});
@@ -143,36 +143,10 @@ struct BlockGemmARegBSmemCRegV2PrefetchK
__builtin_amdgcn_sched_barrier(0);
b_warp_windows(nIter)(I1) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(I1),
{nIter * NPerBlockPerIter, 1 * KPerBlockPerIter});
b_warp_tensors[I1] = load_tile(b_warp_windows(nIter)(I1));
__builtin_amdgcn_sched_barrier(0);
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, 0>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// warp GEMM
auto c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[I0]);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
static_for<1, KIterPerWarp, 1>{}([&](auto kIter) {
// read B warp tensor from B Block window
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
if constexpr(kIter < KIterPerWarp - 1)
{
// read B warp tensor from B Block window
b_warp_windows(nIter)(number<kIter + 1>{}) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(number<kIter + 1>{}),
{nIter * NPerBlockPerIter, (kIter + 1) * KPerBlockPerIter});
@@ -193,13 +167,22 @@ struct BlockGemmARegBSmemCRegV2PrefetchK
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
if constexpr(kIter == 0)
{
// warp GEMM
c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[kIter]);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
}
else
{
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
};
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(

View File

@@ -115,42 +115,31 @@ struct BlockGemmARegBSmemCRegV2PrefetchN
using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0)));
statically_indexed_array<statically_indexed_array<b_warp_tensor_type, KIterPerWarp>,
NIterPerWarp>
b_warp_tensors;
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
statically_indexed_array<b_warp_tensor_type, NIterPerWarp> b_warp_tensors;
// read B warp tensor from B Block window
b_warp_windows(I0)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(I0)(kIter),
{0 * NPerBlockPerIter, kIter * KPerBlockPerIter});
b_warp_tensors(I0)(kIter) = load_tile(b_warp_windows(I0)(kIter));
});
b_warp_tensors(I0) = load_tile(b_warp_windows(I0)(kIter));
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter < NIterPerWarp - 1)
{
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter < NIterPerWarp - 1)
{
// read B warp tensor from B Block window
b_warp_windows(number<nIter + 1>{})(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(number<nIter + 1>{})(kIter),
{(nIter + 1) * NPerBlockPerIter, kIter * KPerBlockPerIter});
b_warp_tensors(number<nIter + 1>{})(kIter) =
b_warp_tensors(number<nIter + 1>{}) =
load_tile(b_warp_windows(number<nIter + 1>{})(kIter));
});
};
};
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_sched_barrier(0x0000001);
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
@@ -158,14 +147,22 @@ struct BlockGemmARegBSmemCRegV2PrefetchN
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter][kIter]);
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}