mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
* Update license year
* Add initial code to override decode problem
* Fix splitkv traits/args overriding error
* Reshape and transpose lse for decode
* Remove debug code
* Prettify example code
* Use better function name
* Add kMergeNumHeadGroupsSeqLenQ flag
Kernel users can use this switch to turn the optimization on or off
for specific problem sizes
* Add missing flag declarations
* Default turn off kMergeNumHeadGroupsSeqLenQ in codegen
* Group similar statements together
* Remove assumption of seqlen_q=1
* Remove kMergeNumHeadGroupsSeqLenQ from splitkv combine kernel
* Support kMergeNumHeadGroupsSeqLenQ=true in fmha splitkv kernel
* Run kMergeNumHeadGroupsSeqLenQ=true kernels when needed
* Fix group mode block skip logic
* Undo changes of normal fwd kernel
* Update GridSize() and use GridSize() for the splitkv kernel (#1799)
---------
Co-authored-by: Qianfeng <qianfeng.zhang@amd.com>
[ROCm/composable_kernel commit: 24b12d04af]
47 lines
3.0 KiB
C++
47 lines
3.0 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
|
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
|
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
|
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
|
|
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
|
|
#include "ck_tile/ops/fmha/block/page_block_navigator.hpp"
|
|
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
|
|
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
|
|
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
|
|
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
|
|
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
|
|
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
|
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
|
#include "ck_tile/ops/common/tensor_layout.hpp"
|