temp save

This commit is contained in:
aska-0096
2025-07-17 10:06:09 +00:00
parent 7e330553dc
commit 94b6430489
11 changed files with 298 additions and 325 deletions

View File

@@ -757,7 +757,9 @@ struct FmhaFwdDecodeKernel
const auto make_k_dram = [&](const KDataType* data, index_t height) {
// We don't expect K data reuse among different blocks in decode case.
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global, memory_operation_enum::set, amd_buffer_coherence_enum::SYSTEM_NT1>(
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::set,
amd_buffer_coherence_enum::SYSTEM_NT1>(
data, // will update this pointer if using paged-kvcache
make_tuple(height, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
@@ -784,12 +786,15 @@ struct FmhaFwdDecodeKernel
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
// We don't expect V data reuse among different blocks in decode case.
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global, memory_operation_enum::set, amd_buffer_coherence_enum::SYSTEM_NT1>(
data, // will update this pointer if using paged-kvcache
make_tuple(length, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_naive =
make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::set,
amd_buffer_coherence_enum::SYSTEM_NT1>(
data, // will update this pointer if using paged-kvcache
make_tuple(length, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_transposed =
transform_tensor_view(v_dram_naive,
@@ -1079,15 +1084,15 @@ struct FmhaFwdDecodeKernel
v_page_block_navigator, // Remove it
bias_dram_window,
lse_acc_dram_window,
kargs.num_splits, // Remove it
i_split_, // Remove it
kargs.num_splits, // Remove it
i_split_, // Remove it
mask,
position_encoding,
kargs.scale_s,
variant, // Remove it
variant_params, // Remove it
block_indices, // Remove it
kv_l2p_offset, // Remove it
variant, // Remove it
variant_params, // Remove it
block_indices, // Remove it
kv_l2p_offset, // Remove it
smem_ptr);
}
}();

View File

@@ -1071,20 +1071,20 @@ struct FmhaFwdSplitKVKernel
{
return FmhaPipeline{}(q_dram_window,
k_dram_window_lengths,
// k_page_block_navigator,
// k_page_block_navigator,
v_dram_window_lengths,
// v_page_block_navigator,
// v_page_block_navigator,
bias_dram_window,
lse_acc_dram_window,
// kargs.num_splits,
// i_split_,
// kargs.num_splits,
// i_split_,
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
// kv_l2p_offset,
// kv_l2p_offset,
smem_ptr);
}
}();

View File

@@ -11,8 +11,7 @@
namespace ck_tile {
// This pipeline is qkv all located in LDS
template <typename Problem_,
typename Policy_ = BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy>
template <typename Problem_, typename Policy_ = BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy>
struct BlockFmhaFwdDecodePipelineQRKSVS
{
using Problem = remove_cvref_t<Problem_>;
@@ -52,11 +51,13 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
// static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
// Problem::kPadHeadDimV == true);
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; // support multiple of vector(like 8x)
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ =
Problem::kPadHeadDimQ; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV =
Problem::kPadHeadDimV; // support multiple of vector(like 8x)
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
@@ -130,33 +131,17 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename LSEaccElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
const KPageBlockNavigator& k_page_block_navigator,
const KElementFunction& k_element_func,
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
const VPageBlockNavigator& v_page_block_navigator,
const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
const LSEaccElementFunction& lse_acc_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
index_t num_splits,
index_t i_split,
FmhaMask mask,
@@ -184,56 +169,11 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// Q tile in LDS
QDataType* q_lds_ptr =
static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
// K tile in LDS
KDataType* k_lds_ptr =
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
max(Policy::template GetSmemSizeQ<Problem>(),
Policy::template GetSmemSizeK<Problem>())),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// S tile in LDS
auto s_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<SaccDataType*>(reinterpret_cast<char*>(smem_ptr) +
max(Policy::template GetSmemSizeQ<Problem>(),
Policy::template GetSmemSizeK<Problem>())),
Policy::template MakeSLdsBlockDescriptor<Problem>());
auto s_write_lds_window = make_tile_window(
s_lds, Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
auto s_read_lds_window =
make_tile_window(s_lds,
Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0},
Policy::template MakeSRegTileDistribution<Problem>());
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>());
// load Q here, will store Q into LDS to maximize throughput
auto origin_q = load_tile(q_dram_window);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
@@ -259,7 +199,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
@@ -279,8 +219,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
if(get_thread_local_1d_id() < kM0)
{
store_tile(lse_acc_dram_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
store_tile(lse_acc_dram_window_tmp, lse_acc);
}
}
@@ -290,6 +229,25 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
}
}
// Q tile in LDS
auto q_dram_window = make_tile_window(
q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution<Problem>());
auto q_lds = make_tensor_view<address_space_enum::lds>(
static_cast<QDataType*>(smem_ptr), Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_store_window = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
auto q_lds_read_window =
make_tile_window(q_lds,
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0},
Policy::template MakeQRegTileDistribution<Problem>());
async_load_tile(q_lds_store_window, q_dram_window);
// K tile in LDS
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
// make sure the first tile is completely located in page-block (page-block size should be
@@ -307,12 +265,59 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
return physical_seqlen_k_start_;
}
}();
const index_t num_total_loop =
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
auto k_dram_window = make_tile_window(
k_dram_block_window, Policy::template MakeKDramTileDistribution<Problem>());
auto k_lds = make_tensor_view<address_space_enum::lds>(
static_cast<KDataType*>(smem_ptr), Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
auto k_lds_read_window =
make_tile_window(k_lds,
make_tuple(number<kN0>{}, number<kK0>{}),
{0, 0},
Policy::template MakeKRegTileDistribution<Problem>());
// S tile in LDS
auto s_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<SaccDataType*>(reinterpret_cast<char*>(smem_ptr) +
max(Policy::template GetSmemSizeQ<Problem>(),
Policy::template GetSmemSizeK<Problem>())),
Policy::template MakeSLdsBlockDescriptor<Problem>());
auto s_write_lds_window = make_tile_window(
s_lds, Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
auto s_read_lds_window =
make_tile_window(s_lds,
Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0},
Policy::template MakeSRegTileDistribution<Problem>());
// V tile in LDS
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
max(Policy::template GetSmemSizeQ<Problem>(),
Policy::template GetSmemSizeK<Problem>()) +
Policy::template GetSmemSizeS<Problem>()),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_write_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
auto v_lds_read_window =
make_tile_window(v_lds,
Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0},
Policy::template MakeVRegTileDistribution<Problem>());
// Bias tile in LDS
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -322,51 +327,26 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
aligned_physical_seqlen_k_start)}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
block_sync_lds_direct_load<0>();
auto q_tile = load_tile(q_lds_read_window);
// store Q into LDS
__builtin_amdgcn_sched_barrier(0);
auto q_lds_window_for_store = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
const index_t num_total_loop =
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
store_tile(q_lds_window_for_store, origin_q);
__builtin_amdgcn_sched_barrier(0);
// load Q from LDS
__builtin_amdgcn_sched_barrier(0);
auto q_lds_window_for_load =
make_tile_window(q_lds,
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0},
Policy::template MakeQRegTileDistribution<Problem>());
block_sync_lds();
auto q = load_tile(q_lds_window_for_load);
__builtin_amdgcn_sched_barrier(0);
auto q_tile = tile_elementwise_in(q_element_func, q);
// prefetch K tile
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(1 <= k0_loops);
static_assert(1 <= k1_loops);
static_assert(1 == k0_loops);
static_assert(1 == k1_loops);
auto k_dram_window = make_tile_window(
k_dram_block_window,
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
async_load_tile(k_lds_write_window, k_dram_window);
// move K tile windows
i_page_block_k =
k_page_block_navigator.move_tile_window(i_page_block_k, k_dram_block_window, {kN0, 0});
// load the first tile of the first iteration and store to LDS
auto k_block_tile = load_tile(k_dram_window);
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window(k_dram_window, {0, kK0});
// ensure LDS access by Q is done before the over-writting by K
block_sync_lds();
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
k_dram_window = make_tile_window(k_dram_block_window,
Policy::template MakeKDramTileDistribution<Problem>());
do
{
@@ -385,40 +365,24 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
0); // prevent from messing up the order of global loads
}
if constexpr(k0_loops > 1)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
k_block_tile = load_tile(k_dram_window); // global read i + 1
block_sync_lds();
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window);
block_sync_lds();
move_tile_window(k_dram_window, {0, kK0});
async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile
// move V tile windows
i_page_block_v =
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
store_tile(
k_lds_window,
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
});
}
block_sync_lds_direct_load<v_dram_window.get_num_of_access()>();
auto k_tile = load_tile(k_lds_read_window);
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
block_sync_lds();
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
k_lds_window);
}
gemm_0(
s_acc,
get_slice_tile(
q_tile, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
get_slice_tile(
k_tile, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kN0, k0_loops * kK0>{}));
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
@@ -437,7 +401,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
@@ -455,7 +418,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
if constexpr(kHasLogitsSoftCap)
{
auto apply_logits_transform =
@@ -530,39 +492,30 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
}
}
__builtin_amdgcn_sched_barrier(0);
async_load_tile(k_lds_write_window, k_dram_window);
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
// load the first tile for next iteration
if(i_total_loops < num_total_loop - 1)
{
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
k_dram_window = make_tile_window(k_dram_block_window,
Policy::template MakeKDramTileDistribution<Problem>());
k_dram_window = make_tile_window(
k_dram_block_window,
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window
// laod the first tile of the first iteration and store to LDS
k_block_tile = load_tile(k_dram_window);
}
__builtin_amdgcn_sched_barrier(0);
// In Nwarp=1 and NXdl=32, GEMM0 output naturally fit the input of GEMM1
// Otherwise shuffle through LDS so that the tile layout is consistent with required by Gemm1
auto s_new = [&](){
if constexpr ( !((kNWarp==1) && (kNXdl == 32)) ){
// Otherwise shuffle through LDS so that the tile layout is consistent with required by
// Gemm1
auto s_new = [&]() {
if constexpr(!((kNWarp == 1) && (kNXdl == 32)))
{
auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
store_tile(s_write_lds_window, s);
block_sync_lds();
return load_tile(s_read_lds_window);
}
else{
else
{
return cast_tile<SMPLComputeDataType>(s_acc); // S{j}
}
}();
}();
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s_new,
@@ -630,8 +583,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
const auto p = cast_tile<PDataType>(p_compute);
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
@@ -670,79 +622,16 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
});
});
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_prefetch);
store_tile(
v_lds_window,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
}
i_page_block_v =
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
block_sync_lds_direct_load<k_dram_window.get_num_of_access()>();
auto v_tile = load_tile_transpose(v_lds_read_window);
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&,
&i_page_block_v_ = i_page_block_v,
&v_dram_window_ = v_dram_window](auto i_k1) {
const auto v = load_tile(v_dram_window_); // load next v
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, k1_loops * kK1>{}),
get_slice_tile(v_tile,
sequence<0, (k1_loops - 1) * kK1>{},
sequence<kN1, k1_loops * kK1>{}));
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window);
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v);
store_tile(v_lds_window,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v)); // store next v
}
i_page_block_v_ = v_page_block_navigator.move_tile_window(
i_page_block_v_, v_dram_window_, {0, kK1});
});
}
// tail
{
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, k1_loops * kK1>{}),
v_lds_window);
block_sync_lds();
}
__builtin_amdgcn_sched_barrier(0);
// load the first tile for next iteration
if(i_total_loops < num_total_loop - 1)
{
// store the first tile for next iteration to LDS
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window(k_dram_window, {0, kK0});
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
}
} while(++i_total_loops < num_total_loop);
if constexpr(kStoreLSE)
@@ -777,8 +666,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
if(get_thread_local_1d_id() < kM0)
{
store_tile(lse_acc_dram_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
store_tile(lse_acc_dram_window_tmp, lse_acc);
}
}
@@ -802,66 +690,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowLengths,
typename KPageBlockNavigator,
typename VDramBlockWindowLengths,
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
const KPageBlockNavigator& k_page_block_navigator,
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
const VPageBlockNavigator& v_page_block_navigator,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_lengths,
k_page_block_navigator,
identity{},
v_dram_block_window_lengths,
v_page_block_navigator,
identity{},
bias_dram_block_window_tmp,
identity{},
lse_acc_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
num_splits,
i_split,
mask,
position_encoding,
scale_s,
variant,
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
}
};
} // namespace ck_tile

View File

@@ -7,6 +7,11 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
namespace ck_tile {
@@ -116,6 +121,137 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
return q_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using GemmProblem =
BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
true,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPVBlockGemm()
{
using GemmProblem =
BlockGemmProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegTileDistribution()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto k_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
return k_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegTileDistribution()
{
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
return v_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS()
{