mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
Improve the performance of qr_ks_vs_whole_k_prefetch pipeline (#6209) ## About qr_ks_vs_whole_k_prefetch pipeline This PR updates and enhances the qr_ks_vs_whole_k_prefetch pipeline to improve performance on both MI350 GPUs through better MFMA instruction usage, transposed V-loading support, and N0-loop implementation. The pipeline targets scenarios where the number of workgroups is low, enabling better CU occupancy by using smaller MTile sizes (kM0=64 vs 128) while prefetching entire K tiles. ## Changes: - Adds transposed V-loading support (qr_ks_vs_whole_k_prefetch_trload) to avoid using shuffle instructions on MI350 - Implements N0-loop based Gemm0 to reduce tile window movement overhead and eliminate `clear_tile` calls - Adds full support for hdim96/hdim160 without padding requirements - Updates MFMA instruction selection to ensure optimal choices for MI350 ## Performance results 1. For attention shapes which leads to kM0=64, `qr_ks_vs_async_whole_k_prefetch_trload` shows much better performance than `qr_ks_vs_async_trload` on the same case (execution time `41.02ms` by whole_k_prefetch_trload & `58.50ms` by async_load), and `qr_ks_vs_async_whole_k_prefetch_trload` also shows obviously better performance than the recently tuned `qr_ks_vs_async` on the same case (execution time `41.02ms` by whole_k_prefetch_trload 7 `47.60ms` by qr_ks_vs_async) 2. Also on MI300, for attention shapes which leads to kM0=64, `qr_ks_vs_async_whole_k_prefetch` shows much better performance than the `qr_ks_vs_async` (which is supposed to be very high-efficient) on the same case (execution time `64.50ms` by whole_k_prefetch & `80.20ms` by qr_ks_vs_async) 3. For attention shapes which leads to kM0=128, `qr_ks_vs_async_whole_k_prefetch_trload` show a little bit better performance than `qr_ks_vs_async` on mi350 (execution time `104.50ms` by whole_k_prefetch_trload & `106.50ms` by qr_ks_vs_async). And they shows completely on-par performance on MI300 ## Test/Verify 1. Use the ROCM xformers branch `test_whole_k_prefetch_n0loop` to test/verify qr_ks_vs_whole_k_prefetch pipeline since this pipeline can not be used by ck_tile fmha example so far 2. Use the following command-line for building/testing xformers >```bash > #> git clone -b test_whole_k_prefetch_n0loop https://github.com/ROCm/xformers > #> git submodule update --init --recursive > #> pip install --no-build-isolation -e ./ > #> pytest tests/test_mem_eff_attention.py::test_forward >``` 4. Any scripts which can run on xformers can be used to evaluate qr_ks_vs_whole_k_prefetch pipeline. Using the two environ variable to switch from using different pipelines > ```bash > #> export FMHA_DISABLE_SPECIAL_TREATMENT=1 #> to disable using FAV3 and qr_ks_vs_async_trload pipeline > #> export FMHA_ENABLE_ASYNC_PIPELINE=1 #> to disable using qr_ks_vs_async pipeline for comparing > ``` ## Discussion
71 lines
4.9 KiB
C++
71 lines
4.9 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
#pragma once
|
|
|
|
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
|
#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
|
|
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
|
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
|
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
|
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
|
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
|
|
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
|
|
#include "ck_tile/ops/fmha/block/cast_tile_mx.hpp"
|
|
#include "ck_tile/ops/fmha/block/page_block_navigator.hpp"
|
|
#include "ck_tile/ops/fmha/block/variants.hpp"
|
|
#include "ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp"
|
|
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
|
|
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
|
|
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
|
|
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
|
|
#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp"
|
|
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
|
|
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
|
|
#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
|
|
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
|
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
|
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
|
#include "ck_tile/ops/common/streamk_common.hpp"
|
|
#include "ck_tile/ops/common/tensor_layout.hpp"
|
|
#include "ck_tile/ops/common/utils.hpp"
|