This commit is contained in:
aska-0096
2025-07-18 05:16:39 +00:00
parent 94b6430489
commit ae39c84f55
2 changed files with 23 additions and 6 deletions

View File

@@ -297,10 +297,12 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
Policy::template MakeSRegTileDistribution<Problem>());
// V tile in LDS
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
auto [i_page_block_v, v_dram_block_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
{0, aligned_physical_seqlen_k_start});
auto v_dram_window = make_tile_window(
v_dram_block_window, Policy::template MakeVDramTileDistribution<Problem>());
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
@@ -348,6 +350,9 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
k_dram_window = make_tile_window(k_dram_block_window,
Policy::template MakeKDramTileDistribution<Problem>());
constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
do
{
// STAGE 1, QK gemm
@@ -370,7 +375,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
i_page_block_v =
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
block_sync_lds_direct_load<v_dram_window.get_num_of_access()>();
// CK_PRINT<decltype(v_dram_window.get_num_of_access())>();
block_sync_lds_direct_load<v_vmem_insts>();
auto k_tile = load_tile(k_lds_read_window);
gemm_0(
@@ -622,7 +628,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
});
});
block_sync_lds_direct_load<k_dram_window.get_num_of_access()>();
block_sync_lds_direct_load<k_vmem_insts>();
auto v_tile = load_tile_transpose(v_lds_read_window);
gemm_1(o_acc,

View File

@@ -14,7 +14,18 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
namespace ck_tile {
// Use `CK_PRINT<T1, T2, ...>()` to inspect values of type T1, T2, ...
// Use `CK_PRINT<v1, v2, ...>()` to inspect constexpr values of val1, val2, ... of the same type
// In a non-evaluated context, you can use `using _dummy = decltype(CK_PRINT<...>());`
// Set BUILD_DEV to OFF to avoid enabling Werror
template <auto... val>
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
{
}
template <typename... type>
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
{
}
// This pipeline is qkv all located in LDS
struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,