mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-11 16:59:10 +00:00
[CK Tile] Add conv Wavelet GEMM pipeline and bwd_weight instances (#7937) ## Motivation CK Tile had no pipeline competitive with old CK's wavelet pipeline on the RetinaNet K=36 C=256 3x3 conv bwd_weight class. This adds a wave-specialized "wavelet" GEMM pipeline so CK Tile has a competitive kernel for spatial small-K shapes. ## Technical Details - New wavelet GEMM pipeline (`gemm_pipeline_ag_bg_cr_wavelet.hpp`): workgroup split into math waves (LDS read + MFMA) and load waves (DRAM read + LDS write). - VGPR role-split: `operator()` has two top-level mutually-exclusive `is_math` branches so the allocator overlays both roles onto the same physical VGPRs, cutting arch VGPR usage by ~33-40% and raising occupancy. Correctness depends on identical `block_sync_lds` counts on both arms plus a matching load-wave barrier stub in the epilogue (`cshuffle_epilogue.hpp`). - Kernel dispatch (`grouped_convolution_backward_weight_kernel.hpp`): `kIsWavelet` path, `LaunchBlockSize`, load-wave barrier stub. Uplift: wavelet is the fastest CK Tile pipeline on the RetinaNet K=36 C=256 3x3 family, beating the best non-wavelet CK Tile kernel by 10-27% (googlenet K=320 by 16-23%); the role-split roughly halves the parity gap vs old CK on the 13x13 fp16 shape. ## Test Plan - `ckProfiler grouped_conv_bwd_weight`, NHWGC layout, fp16/bf16, `split_k=all`, CPU verify on RetinaNet K=36 shapes (7x7, 13x13) and a broad 2D sweep. - Correctness: `-v=1` across `split_k` in {-1,1,2,4,8,16,32,64} (barrier-parity / deadlock check). - `test_grouped_convnd_bwd_weight` over the wavelet instances in the tests' `.conf` files. ## Test Result - All wavelet instances pass CPU verification across the split-K sweep; no hangs (dual-arm barrier sequence matches). - Wavelet wins the RetinaNet K=36 C=256 3x3 family (10-27% over best non-wavelet CK Tile) and googlenet K=320 (16-23%); it is at parity or better vs old CK on the majority of spatial shapes. 
## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
106 lines
7.2 KiB
C++
106 lines
7.2 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
#pragma once
|
|
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_eight_waves_v1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_mx_areg_bsmem_creg_v1_custom_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_v2.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_tdm_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_tdm_v1.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_tdm_v2.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_wavelet.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_async_v1.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_tdm.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_tdm_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_highprec_traits.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp"
|
|
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
|
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
|
#include "ck_tile/ops/common/streamk_common.hpp"
|
|
#include "ck_tile/ops/common/tensor_layout.hpp"
|
|
#include "ck_tile/ops/common/utils.hpp"
|