mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Simplify the codes in block_gemm
This commit is contained in:
@@ -87,24 +87,10 @@ struct BlockGemmARegBSmemCRegV2Hack_0
|
||||
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
|
||||
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
|
||||
#if 0 // FIXME: using array will cause register spill
|
||||
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
|
||||
{b_warp_window_tmp}};
|
||||
|
||||
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
|
||||
{
|
||||
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
}
|
||||
}
|
||||
#else
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
#endif
|
||||
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
@@ -128,7 +114,6 @@ struct BlockGemmARegBSmemCRegV2Hack_0
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
@@ -136,6 +121,7 @@ struct BlockGemmARegBSmemCRegV2Hack_0
|
||||
|
||||
statically_indexed_array<b_warp_tensor_type, KIterPerWarp> b_warp_tensors;
|
||||
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(nIter)(I0) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(I0),
|
||||
{nIter * NPerBlockPerIter, 0 * KPerBlockPerIter});
|
||||
@@ -143,36 +129,10 @@ struct BlockGemmARegBSmemCRegV2Hack_0
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
b_warp_windows(nIter)(I1) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(I1),
|
||||
{nIter * NPerBlockPerIter, 1 * KPerBlockPerIter});
|
||||
b_warp_tensors[I1] = load_tile(b_warp_windows(nIter)(I1));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
auto c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[I0]);
|
||||
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
|
||||
static_for<1, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
// read B warp tensor from B Block window
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
if constexpr(kIter < KIterPerWarp - 1)
|
||||
{
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(nIter)(number<kIter + 1>{}) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(number<kIter + 1>{}),
|
||||
{nIter * NPerBlockPerIter, (kIter + 1) * KPerBlockPerIter});
|
||||
@@ -193,13 +153,22 @@ struct BlockGemmARegBSmemCRegV2Hack_0
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
if constexpr(kIter == 0)
|
||||
{
|
||||
// warp GEMM
|
||||
c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[kIter]);
|
||||
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
|
||||
}
|
||||
else
|
||||
{
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]);
|
||||
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]);
|
||||
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
|
||||
};
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
|
||||
@@ -115,42 +115,31 @@ struct BlockGemmARegBSmemCRegV2Hack_1
|
||||
|
||||
using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0)));
|
||||
|
||||
statically_indexed_array<statically_indexed_array<b_warp_tensor_type, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensors;
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
statically_indexed_array<b_warp_tensor_type, NIterPerWarp> b_warp_tensors;
|
||||
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(I0)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(I0)(kIter),
|
||||
{0 * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
b_warp_tensors(I0)(kIter) = load_tile(b_warp_windows(I0)(kIter));
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter < NIterPerWarp - 1)
|
||||
{
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(number<nIter + 1>{})(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(number<nIter + 1>{})(kIter),
|
||||
{(nIter + 1) * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
b_warp_tensors(number<nIter + 1>{})(kIter) =
|
||||
load_tile(b_warp_windows(number<nIter + 1>{})(kIter));
|
||||
});
|
||||
};
|
||||
b_warp_tensors(I0) = load_tile(b_warp_windows(I0)(kIter));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter < NIterPerWarp - 1)
|
||||
{
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(number<nIter + 1>{})(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(number<nIter + 1>{})(kIter),
|
||||
{(nIter + 1) * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
b_warp_tensors(number<nIter + 1>{}) =
|
||||
load_tile(b_warp_windows(number<nIter + 1>{})(kIter));
|
||||
};
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
@@ -158,14 +147,22 @@ struct BlockGemmARegBSmemCRegV2Hack_1
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter][kIter]);
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -119,42 +119,31 @@ struct BlockGemmARegBSmemTrLoadCRegV2Hack_1
|
||||
|
||||
using b_warp_tensor_type = decltype(load_tile_transpose(b_warp_windows(I0)(I0)));
|
||||
|
||||
statically_indexed_array<statically_indexed_array<b_warp_tensor_type, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensors;
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
statically_indexed_array<b_warp_tensor_type, NIterPerWarp> b_warp_tensors;
|
||||
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(I0)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(I0)(kIter),
|
||||
{kIter * KPerBlockPerIter, 0 * NPerBlockPerIter});
|
||||
b_warp_tensors(I0)(kIter) = load_tile_transpose(b_warp_windows(I0)(kIter));
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter < NIterPerWarp - 1)
|
||||
{
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(number<nIter + 1>{})(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(number<nIter + 1>{})(kIter),
|
||||
{kIter * KPerBlockPerIter, (nIter + 1) * NPerBlockPerIter});
|
||||
b_warp_tensors(number<nIter + 1>{})(kIter) =
|
||||
load_tile_transpose(b_warp_windows(number<nIter + 1>{})(kIter));
|
||||
});
|
||||
};
|
||||
b_warp_tensors(I0) = load_tile_transpose(b_warp_windows(I0)(kIter));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter < NIterPerWarp - 1)
|
||||
{
|
||||
// read B warp tensor from B Block window
|
||||
b_warp_windows(number<nIter + 1>{})(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(number<nIter + 1>{})(kIter),
|
||||
{kIter * KPerBlockPerIter, (nIter + 1) * NPerBlockPerIter});
|
||||
b_warp_tensors(number<nIter + 1>{}) =
|
||||
load_tile_transpose(b_warp_windows(number<nIter + 1>{})(kIter));
|
||||
};
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
@@ -162,15 +151,22 @@ struct BlockGemmARegBSmemTrLoadCRegV2Hack_1
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter][kIter]);
|
||||
});
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user