mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
small refactor
This commit is contained in:
@@ -92,19 +92,11 @@ CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
|
||||
|
||||
// Returns the x component of the built-in block index (blockIdx.x) as an index_t.
CK_TILE_DEVICE index_t get_block_id()
{
    return static_cast<index_t>(blockIdx.x);
}
|
||||
|
||||
// block_sync_lds: workgroup-level synchronization helper.
//
// As listed here, when CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM is
// enabled the function issues s_waitcnt followed by s_barrier via the AMDGCN
// builtins; otherwise it falls back to __syncthreads().
//
// Template parameter:
//   lgkmcnt - value forwarded to CK_TILE_LGKMCNT(...) and masked with
//             CK_TILE_S_CNT_MAX to form the s_waitcnt operand (default 0).
//             Presumably the number of LGKM-counter operations allowed to stay
//             in flight before the barrier - TODO confirm against the
//             CK_TILE_LGKMCNT definition.
//
// NOTE(review): this listing is a rendered diff without +/- markers (the `|`
// and `||||` rows are page residue). The hunk header "-92,19 +92,11" suggests
// that the commented-out asm, the hard-coded 0xc07f wait and possibly the
// #if/#else/#endif with __syncthreads() belong to the removed side, while the
// template line and the CK_TILE_LGKMCNT-based wait are the added side.
// Confirm against the real file before relying on the text below.
template <index_t lgkmcnt = 0>
|
||||
CK_TILE_DEVICE void block_sync_lds()
|
||||
{
|
||||
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
// asm volatile("\
|
||||
// s_waitcnt lgkmcnt(0) \n \
|
||||
// s_barrier \
|
||||
// " ::);
|
||||
|
||||
// Fixed-operand wait; presumably encodes lgkmcnt(0) with the other counters
// left at their maximum, mirroring the commented-out asm above - TODO confirm.
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
// Parameterized wait: operand is built from the lgkmcnt template argument.
__builtin_amdgcn_s_waitcnt(CK_TILE_S_CNT_MAX & CK_TILE_LGKMCNT(lgkmcnt));
|
||||
__builtin_amdgcn_s_barrier();
|
||||
#else
|
||||
// Fallback path when the experimental macro is not enabled.
__syncthreads();
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
|
||||
|
||||
@@ -831,7 +831,8 @@ struct FmhaFwdDecodeKernel
|
||||
|
||||
// TODO: Add kVHeadDim
|
||||
// TrLoad is performed in 16x4/16x8/16x16 units; the fast dimension is 16 elements
|
||||
constexpr auto XorGroupSize = FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
|
||||
constexpr auto XorGroupSize =
|
||||
FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
|
||||
|
||||
const auto v_dram_unmerged = transform_tensor_view(
|
||||
v_dram_pad,
|
||||
|
||||
@@ -758,41 +758,24 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution<Problem, true>());
|
||||
|
||||
auto k_lds_write_view =
|
||||
make_tuple(make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<KDataType* __restrict__>(smem_ptrk0),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem, true>()),
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<KDataType* __restrict__>(smem_ptrk1),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem, true>()));
|
||||
auto k_lds_write_view = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<KDataType* __restrict__>(smem_ptrk0),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem, true>());
|
||||
|
||||
auto k_lds_read_view =
|
||||
make_tuple(make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<KDataType* __restrict__>(smem_ptrk0),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem, true, true>()),
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<KDataType* __restrict__>(smem_ptrk1),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem, true, true>()));
|
||||
auto k_lds_read_view = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<KDataType* __restrict__>(smem_ptrk0),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem, true, true>());
|
||||
|
||||
auto k_lds_write_windows =
|
||||
make_tuple(make_tile_window(
|
||||
k_lds_write_view.at(I0),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem, true>().get_lengths(),
|
||||
{0, 0}),
|
||||
make_tile_window(
|
||||
k_lds_write_view.at(I1),
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem, true>().get_lengths(),
|
||||
{0, 0}));
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds_write_view,
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0});
|
||||
|
||||
auto k_lds_read_windows =
|
||||
make_tuple(make_tile_window(k_lds_read_view.at(I0),
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeKRegTileDistribution<Problem>()),
|
||||
make_tile_window(k_lds_read_view.at(I1),
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeKRegTileDistribution<Problem>()));
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds_read_view,
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeKRegTileDistribution<Problem>());
|
||||
|
||||
// S tile in LDS
|
||||
auto s_lds = make_tensor_view<address_space_enum::lds>(
|
||||
@@ -811,39 +794,24 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto v_lds_write_view = make_tuple(
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv0)),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>()),
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv1)),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>()));
|
||||
auto v_lds_write_view = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv0)),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto v_lds_read_view = make_tuple(
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv0)),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem, true>()),
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv1)),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem, true>()));
|
||||
auto v_lds_read_view = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv0)),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem, true>());
|
||||
|
||||
auto v_lds_write_windows = make_tuple(
|
||||
make_tile_window(v_lds_write_view.at(I0),
|
||||
auto v_lds_write_window =
|
||||
make_tile_window(v_lds_write_view,
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0}),
|
||||
make_tile_window(v_lds_write_view.at(I1),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0}));
|
||||
{0, 0});
|
||||
|
||||
auto v_lds_read_windows =
|
||||
make_tuple(make_tile_window(v_lds_read_view.at(I0),
|
||||
make_tuple(number<kK1>{}, number<kN1>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeVRegTileDistribution<Problem>()),
|
||||
make_tile_window(v_lds_read_view.at(I1),
|
||||
make_tuple(number<kK1>{}, number<kN1>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeVRegTileDistribution<Problem>()));
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds_read_view,
|
||||
make_tuple(number<kK1>{}, number<kN1>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeVRegTileDistribution<Problem>());
|
||||
|
||||
// block_sync_lds_direct_load<0>();
|
||||
// auto q_tile = load_tile(q_lds_read_window);
|
||||
@@ -857,31 +825,41 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
static_assert(1 <= k0_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
async_load_tile(k_lds_write_windows.at(I0), k_dram_window);
|
||||
async_load_tile(v_lds_write_windows.at(I0), v_dram_window);
|
||||
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
async_load_tile(v_lds_write_window, v_dram_window);
|
||||
|
||||
move_tile_window(k_dram_window, {kN0, 0});
|
||||
async_load_tile(k_lds_write_windows.at(I1), k_dram_window);
|
||||
k_lds_write_window.set_bottom_tensor_view_data_ptr(
|
||||
static_cast<KDataType* __restrict__>(smem_ptrk1));
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
|
||||
constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
|
||||
constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
|
||||
|
||||
constexpr index_t k_lds_insts = k_lds_read_window.get_num_of_access();
|
||||
constexpr index_t v_lds_insts = v_lds_read_window.get_num_of_access();
|
||||
|
||||
block_sync_lds_direct_load<k_vmem_insts + v_vmem_insts>();
|
||||
auto k_tile = load_tile(k_lds_read_windows.at(I0));
|
||||
auto k_tile = load_tile(k_lds_read_window);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
auto mainloop = [&](index_t cur_loop) {
|
||||
|
||||
auto k_lds_write_window = (cur_loop%2 == 0)? k_lds_write_windows.at(I0) : k_lds_write_windows.at(I1);
|
||||
auto k_lds_read_window_cur = (cur_loop%2 == 0)? k_lds_read_windows.at(I0) : k_lds_read_windows.at(I1);
|
||||
auto k_lds_read_window_next = (cur_loop%2 == 0)? k_lds_read_windows.at(I1) : k_lds_read_windows.at(I0);
|
||||
auto v_lds_write_window = (cur_loop%2 == 0)? v_lds_write_windows.at(I1) : v_lds_write_windows.at(I0);
|
||||
auto v_lds_read_window = (cur_loop%2 == 0)? v_lds_read_windows.at(I0) : v_lds_read_windows.at(I1);
|
||||
const bool is_even_loop = (cur_loop % 2 == 0);
|
||||
|
||||
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
|
||||
: static_cast<KDataType* __restrict__>(smem_ptrk1);
|
||||
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
|
||||
: static_cast<KDataType* __restrict__>(smem_ptrk0);
|
||||
auto v_lds_write_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv1)
|
||||
: static_cast<VDataType* __restrict__>(smem_ptrv0);
|
||||
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
|
||||
: static_cast<VDataType* __restrict__>(smem_ptrv1);
|
||||
|
||||
// move V tile windows
|
||||
block_sync_lds();
|
||||
block_sync_lds<k_lds_insts>();
|
||||
move_tile_window(v_dram_window, {kN0, 0});
|
||||
v_lds_write_window.set_bottom_tensor_view_data_ptr(v_lds_write_ptr);
|
||||
async_load_tile(v_lds_write_window, v_dram_window);
|
||||
|
||||
// STAGE 1, QK gemm
|
||||
@@ -891,19 +869,19 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
// loop along the [K]ey head dimension
|
||||
move_tile_window(k_lds_read_window_cur, {0, kK0});
|
||||
auto k_tile_switch = load_tile(k_lds_read_window_cur);
|
||||
move_tile_window(k_lds_read_window, {0, kK0});
|
||||
auto k_tile_switch = load_tile(k_lds_read_window);
|
||||
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
k_tile);
|
||||
|
||||
|
||||
k_tile = k_tile_switch;
|
||||
});
|
||||
// move back to the origin
|
||||
move_tile_window(k_lds_read_window_cur, {0, -kK0 * (k0_loops - 1)});
|
||||
move_tile_window(k_lds_read_window, {0, -kK0 * (k0_loops - 1)});
|
||||
}
|
||||
|
||||
gemm_0(s_acc,
|
||||
@@ -911,21 +889,10 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
k_tile);
|
||||
|
||||
block_sync_lds_direct_load<k_vmem_insts + v_vmem_insts>();
|
||||
auto v_tile = load_tile_transpose(v_lds_read_window);
|
||||
|
||||
static_for<0, 14, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
|
||||
});
|
||||
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ
|
||||
});
|
||||
block_sync_lds_direct_load<k_vmem_insts + v_vmem_insts>();
|
||||
v_lds_read_window.set_bottom_tensor_view_data_ptr(v_lds_read_ptr);
|
||||
auto v_tile = load_tile_transpose(v_lds_read_window);
|
||||
|
||||
if constexpr(kHasUnevenSplits)
|
||||
{
|
||||
@@ -991,6 +958,18 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
static_for<0, 12, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
|
||||
});
|
||||
|
||||
static_for<0, 4, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ
|
||||
});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
|
||||
@@ -1079,8 +1058,9 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
block_sync_lds<v_lds_insts>();
|
||||
move_tile_window(k_dram_window, {kN0, 0});
|
||||
k_lds_write_window.set_bottom_tensor_view_data_ptr(k_lds_write_ptr);
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
|
||||
if constexpr(1 < k1_loops)
|
||||
@@ -1108,15 +1088,16 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
sequence<kM0, k1_loops * kK1>{}),
|
||||
v_tile);
|
||||
|
||||
k_tile = load_tile(k_lds_read_window_next);
|
||||
|
||||
static_for<0, 14, 1>{}([&](auto i) {
|
||||
k_lds_read_window.set_bottom_tensor_view_data_ptr(k_lds_read_ptr);
|
||||
k_tile = load_tile(k_lds_read_window);
|
||||
|
||||
static_for<0, 12, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ
|
||||
});
|
||||
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
static_for<0, 4, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
|
||||
|
||||
@@ -292,7 +292,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
constexpr auto v_lds_block_desc = [&]() {
|
||||
if constexpr(Xor)
|
||||
{
|
||||
constexpr auto XorGroupSize = Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
|
||||
constexpr auto XorGroupSize =
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
|
||||
|
||||
constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock>{}, number<kNPerBlock>{}),
|
||||
@@ -303,27 +304,24 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
const auto v_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
v_lds_block_desc_naive,
|
||||
make_tuple(make_pass_through_transform(number<kKPerBlock>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<kNPerBlock / XorGroupSize>{},
|
||||
number<XorGroupSize>{}))),
|
||||
make_unmerge_transform(make_tuple(
|
||||
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
const auto v_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
v_lds_block_desc_unmerged,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<kKPerBlock>{},
|
||||
number<kNPerBlock / XorGroupSize>{})),
|
||||
make_pass_through_transform(number<XorGroupSize>{})),
|
||||
make_tuple(make_xor_transform(make_tuple(number<kKPerBlock>{},
|
||||
number<kNPerBlock / XorGroupSize>{})),
|
||||
make_pass_through_transform(number<XorGroupSize>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
v_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<kKPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<kNPerBlock / XorGroupSize>{},
|
||||
number<XorGroupSize>{}))),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user