mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
tempsave
This commit is contained in:
@@ -297,10 +297,12 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
|||||||
Policy::template MakeSRegTileDistribution<Problem>());
|
Policy::template MakeSRegTileDistribution<Problem>());
|
||||||
|
|
||||||
// V tile in LDS
|
// V tile in LDS
|
||||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
auto [i_page_block_v, v_dram_block_window] = v_page_block_navigator.make_tile_window(
|
||||||
v_dram_block_window_lengths,
|
v_dram_block_window_lengths,
|
||||||
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
|
{0, aligned_physical_seqlen_k_start});
|
||||||
Policy::template MakeVDramTileDistribution<Problem>());
|
|
||||||
|
auto v_dram_window = make_tile_window(
|
||||||
|
v_dram_block_window, Policy::template MakeVDramTileDistribution<Problem>());
|
||||||
|
|
||||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||||
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
|
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
|
||||||
@@ -348,6 +350,9 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
|||||||
k_dram_window = make_tile_window(k_dram_block_window,
|
k_dram_window = make_tile_window(k_dram_block_window,
|
||||||
Policy::template MakeKDramTileDistribution<Problem>());
|
Policy::template MakeKDramTileDistribution<Problem>());
|
||||||
|
|
||||||
|
constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
|
||||||
|
constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
|
||||||
|
|
||||||
do
|
do
|
||||||
{
|
{
|
||||||
// STAGE 1, QK gemm
|
// STAGE 1, QK gemm
|
||||||
@@ -370,7 +375,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
|||||||
i_page_block_v =
|
i_page_block_v =
|
||||||
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
|
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
|
||||||
|
|
||||||
block_sync_lds_direct_load<v_dram_window.get_num_of_access()>();
|
// CK_PRINT<decltype(v_dram_window.get_num_of_access())>();
|
||||||
|
block_sync_lds_direct_load<v_vmem_insts>();
|
||||||
auto k_tile = load_tile(k_lds_read_window);
|
auto k_tile = load_tile(k_lds_read_window);
|
||||||
|
|
||||||
gemm_0(
|
gemm_0(
|
||||||
@@ -622,7 +628,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
block_sync_lds_direct_load<k_dram_window.get_num_of_access()>();
|
block_sync_lds_direct_load<k_vmem_insts>();
|
||||||
auto v_tile = load_tile_transpose(v_lds_read_window);
|
auto v_tile = load_tile_transpose(v_lds_read_window);
|
||||||
|
|
||||||
gemm_1(o_acc,
|
gemm_1(o_acc,
|
||||||
|
|||||||
@@ -14,7 +14,18 @@
|
|||||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
|
||||||
|
|
||||||
namespace ck_tile {
|
namespace ck_tile {
|
||||||
|
// Use `CK_PRINT<T1, T2, ...>()` to inspect the types T1, T2, ...
|
||||||
|
// Use `CK_PRINT<val1, val2, ...>()` to inspect the constexpr values val1, val2, ... of the same type
|
||||||
|
// In a non-evaluated context, you can use `using _dummy = decltype(CK_PRINT<...>());`
|
||||||
|
// Set BUILD_DEV to OFF to avoid enabling Werror
|
||||||
|
// Value overload: takes a pack of compile-time constants as non-type template arguments.
// The body is intentionally empty. The function exists only so that a call site trips the
// [[deprecated]] diagnostic; presumably the compiler message then spells out the
// instantiated template arguments (TODO confirm against actual compiler output).
template <auto... val>
|
||||||
|
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||||
|
{
|
||||||
|
}
|
||||||
|
// Type overload: takes a pack of types as template arguments, e.g. CK_PRINT<decltype(expr)>().
// The body is intentionally empty. The function exists only so that a call site trips the
// [[deprecated]] diagnostic; presumably the compiler message then spells out the
// instantiated types (TODO confirm against actual compiler output).
template <typename... type>
|
||||||
|
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT()
|
||||||
|
{
|
||||||
|
}
|
||||||
// This pipeline is qkv all located in LDS
|
// This pipeline is qkv all located in LDS
|
||||||
struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||||
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
|
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
|
||||||
|
|||||||
Reference in New Issue
Block a user