mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-18 12:00:07 +00:00)
* Re-implement qr_ks_vs_async pipeline by using kLoadOnce
* Remove last block_sync_lds() in the loop
* Tiny adjustment in qr_ks_vs_async pipeline for better performance
* Rename MakeQDramTileDistribution to MakeQRegTileDistribution for QLoadOnce pipeline
* Use LDS as intermediary stop when loading Q from global memory for qr_ks_vs_async pipeline
* Use un-rolled gemm for Gemm-0
* Use k0_loops small tile load/store to replace the big tile load/store for K
* Remove the commented lines in qx_ks_vs_custom_policy.hpp
* Tune the prefetching of V in qr_ks_vs_async pipeline
* Move the codes for storing the first v_lds tile some later
* Let BlockDropout reuse LDS with V
* Switch to separate code blocks according to iteration index
* Interleave code blocks for better performance
* Move clear_tile(s_acc) for better interleaving
* Move code interleaving
* Use MakeQDramTileDistribution for q_dram_window
* Roll-back to load Q directly from global memory instead of using LDS as intermediary stop
* Let V reuse the LDS of K
* Use array of tiles to represent Q in vgprs
* Use QLoadOnce == false for qr_ks_vs_async pipeline
* Special treatment for hdim-96 to save vgprs in qr_ks_vs_async pipeline
* Define statically indexed array k_lds_windows[] to reduce the using of get_slice_tile()
* Move the definition of v_tiles out from the loop
* Define statically indexed array v_lds_windows[] to reduce using of get_slice_tile()
* Remove using KLoadOnce in qx_ks_vs_custom_policy
* Remove un-used get_slice_tile() call
* Move the code line of clear_tile(s_acc)
* Tune the lines of codes to make them more tidy
* Re-arrange the codes before the main-loop
* Add comments
* Unify the alignment to be 8 for Q/K/V Lds descriptors
* Tuning to K pre-loading
* Tune K Lds and V Lds reuse for kPreloadWholeNextIterationK == false
* Adjust the pipeline codes
* Use NumPrefetchV to separate from NumVLdsBuffers
* Tune the location of a scheduler barrier code line
* Prefetch first v_tile at earlier time for both kPreloadNextWholeIterationK true/false paths
* Adjust the using of kPadSeqLenQ and kPadSeqLenK in the kernel
* Use __builtin_amdgcn_sched_barrier(0x7f) in the pipeline
* Move the location for store_tile() of first v_tile
* Rename the qr_ks_vs_async pipeline to qr_ks_vs_whole_k_prefetch pipeline
* Re-add NumPrefetchK as template for BlockFmhaPipelineQXKSVSCustomPolicy<>
* Try to fix old bugs in qx_ks_vs_custom_policy
* Remove K_LDS_LOAD_USE_OFFSET_TRANSFORM code-path to make qr_ks_vs_async and qx_ks_vs_custom_policy simpler
* Fix in MakeKDramTileDistribution() in qx_ks_vs_custom_policy
* Update to LdsBufferSequence and introduce NumKVLdsBuffers for max(NumPrefetchK, NumPrefetchV)
* Tiny Fix (#1888)
* Ck tile/paged attention workaround (#1894)
* Correction in GetRangeAlongX()
* Work-around to solve the failures in test_paged_attention_ck in xformers
* Tiny code adjustment in the qr_ks_vs_whole_k_prefetch pipeline
* Remove one call of move_tile_window for q_dram_window
* Refine the codes in GetNumPrefetchV()/GetNumKLdsBuffers()
* Tiny fix in qr_ks_vs_whole_k_prefetch pipeline
* Adjust the location of codes for storing the first V tile to LDS
* Tiny fix and add comments
* Change GetSmemKPackK size to improve performance
* Move the codes related to K-Lds to the pipeline default policy due to some override on the generic custom_policy
* Update MakeKDramTileDistribution() and MakeKLdsDescriptor() to completely remove bank conflicts for K-Lds access
* Adjustment in intermediate iteration codes for tiny performance improvement
* Reduce the number of VLds buffers to 2 for whole_k_prefetch situation
* Use IsFirstKLdsBufferOverlapLastVLdsBuffer() to avoid potential Lds issue
* Adjust the code location for calling IsFirstKLdsBufferOverlapLastVLdsBuffer()
* Remove useless AsyncopyV
* Rename MakeQDramTileDistribution to MakeQRegTileDistribution when LDS is not used
* Keep qx_ks_vs_custom_policy work for other pipelines and move whole_k_prefetch specific codes to whole_k_prefetch default policy
* Recover the qr_ks_vs_async pipeline
* Recover qr_ks_vs_async in fmha.hpp and tiny fix in qr_ks_vs pipeline
* Revert "Try to fix old bugs in qx_ks_vs_custom_policy". This reverts commit 39b82ca194.
* Tiny fix with regard to whole_k_prefetch pipeline compiling
* Update kPadSeqLenK setting in fmha_fwd_kernel
* Use q_element_func and k_element_func
* Use single q_tile rather than multiple sliced q_tiles
* Codes refine according to the comments
* Re-format one file
* Mark qr_ks_vs_whole_k_prefetch as QLoadOnce == true

[ROCm/composable_kernel commit:4f54fa3058]
50 lines
3.3 KiB
C++
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
#include "ck_tile/ops/fmha/block/page_block_navigator.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"