mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-25 09:37:42 +00:00
tempsave
This commit is contained in:
@@ -297,10 +297,12 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
Policy::template MakeSRegTileDistribution<Problem>());
|
||||
|
||||
// V tile in LDS
|
||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||
auto [i_page_block_v, v_dram_block_window] = v_page_block_navigator.make_tile_window(
|
||||
v_dram_block_window_lengths,
|
||||
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
{0, aligned_physical_seqlen_k_start});
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window, Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
|
||||
@@ -348,6 +350,9 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
k_dram_window = make_tile_window(k_dram_block_window,
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
|
||||
constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
|
||||
|
||||
do
|
||||
{
|
||||
// STAGE 1, QK gemm
|
||||
@@ -370,7 +375,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
i_page_block_v =
|
||||
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
|
||||
|
||||
block_sync_lds_direct_load<v_dram_window.get_num_of_access()>();
|
||||
// CK_PRINT<decltype(v_dram_window.get_num_of_access())>();
|
||||
block_sync_lds_direct_load<v_vmem_insts>();
|
||||
auto k_tile = load_tile(k_lds_read_window);
|
||||
|
||||
gemm_0(
|
||||
@@ -622,7 +628,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds_direct_load<k_dram_window.get_num_of_access()>();
|
||||
block_sync_lds_direct_load<k_vmem_insts>();
|
||||
auto v_tile = load_tile_transpose(v_lds_read_window);
|
||||
|
||||
gemm_1(o_acc,
|
||||
|
||||
@@ -14,7 +14,18 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Use `CK_PRINT<T1, T2, ...>()` to inspect values of type T1, T2, ...
|
||||
// Use `CK_PRINT<val1, val2, ...>()` to inspect the constexpr values val1, val2, ... (each `auto` parameter is deduced independently, so their types may differ)
|
||||
// In an unevaluated context, you can use `using _dummy = decltype(CK_PRINT<...>());`
|
||||
// Set BUILD_DEV to OFF to avoid enabling Werror
|
||||
// Compile-time debugging aid, value flavour: accepts a pack of non-type
// template arguments (`auto... val`) and does nothing at run time (empty body).
// The function is deliberately marked [[deprecated]], so any reference to it
// makes the compiler emit a deprecation diagnostic. Per the usage notes in this
// file, the intent is to inspect the supplied values through that diagnostic.
// NOTE(review): exactly how the arguments are rendered in the diagnostic is
// compiler-dependent -- confirm with the toolchain in use.
template <auto... val>
|
||||
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||
{
|
||||
}
|
||||
// Compile-time debugging aid, type flavour: accepts a pack of type template
// arguments (`typename... type`) and does nothing at run time (empty body).
// The function is deliberately marked [[deprecated]], so any reference to it
// makes the compiler emit a deprecation diagnostic. Per the usage notes in this
// file, the intent is to inspect the supplied types through that diagnostic.
// NOTE(review): exactly how the arguments are rendered in the diagnostic is
// compiler-dependent -- confirm with the toolchain in use.
template <typename... type>
|
||||
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||
{
|
||||
}
|
||||
// In this pipeline, Q, K and V are all located in LDS
|
||||
struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
|
||||
|
||||
Reference in New Issue
Block a user