mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-10 16:28:38 +00:00
[CK Tile] Add conv Wavelet GEMM pipeline and bwd_weight instances (#7937)

## Motivation

CK Tile had no pipeline competitive with old CK's wavelet on the RetinaNet K=36 C=256 3x3 conv bwd_weight class. This adds a wave-specialized "wavelet" GEMM pipeline so CK Tile has a competitive kernel for spatial small-K shapes.

## Technical Details

- New wavelet GEMM pipeline (`gemm_pipeline_ag_bg_cr_wavelet.hpp`): workgroup split into math waves (LDS read + MFMA) and load waves (DRAM read + LDS write).
- VGPR role-split: `operator()` has two top-level mutually-exclusive `is_math` branches so the allocator overlays both roles onto the same physical VGPRs, cutting arch VGPR ~33-40% and raising occupancy. Correctness depends on identical `block_sync_lds` counts on both arms plus a matching load-wave barrier stub in the epilogue (`cshuffle_epilogue.hpp`).
- Kernel dispatch (`grouped_convolution_backward_weight_kernel.hpp`): `kIsWavelet` path, `LaunchBlockSize`, load-wave barrier stub.

Uplift: wavelet is the fastest CK Tile pipeline on the RetinaNet K=36 C=256 3x3 family, beating the best non-wavelet CK Tile kernel by 10-27% (googlenet K=320 by 16-23%); the role-split roughly halves the parity gap vs old CK on the 13x13 fp16 shape.

## Test Plan

- `ckProfiler grouped_conv_bwd_weight`, NHWGC layout, fp16/bf16, `split_k=all`, CPU verify on RetinaNet K=36 shapes (7x7, 13x13) and a broad 2D sweep.
- Correctness: `-v=1` across `split_k` in {-1,1,2,4,8,16,32,64} (barrier-parity / deadlock check).
- `test_grouped_convnd_bwd_weight` over the tests `.conf` wavelet instances.

## Test Result

- All wavelet instances CPU-verify correct across the split-K sweep; no hangs (dual-arm barrier sequence matches).
- Wavelet wins the RetinaNet K=36 C=256 3x3 family (10-27% over best non-wavelet CK Tile) and googlenet K=320 (16-23%); at parity-or-better vs old CK on the majority of spatial shapes.
## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
141 lines
7.0 KiB
C++
141 lines
7.0 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
#pragma once
|
|
|
|
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"
|
|
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
|
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
|
|
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
|
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
|
|
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
|
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
|
#include "ck_tile/core/arch/amd_buffer_coherence.hpp"
|
|
#include "ck_tile/core/arch/amd_cluster_load.hpp"
|
|
#include "ck_tile/core/arch/amd_tdm_descriptor.hpp"
|
|
#include "ck_tile/core/arch/amd_transpose_load_encoding.hpp"
|
|
#include "ck_tile/core/arch/amd_wave_read_first_lane.hpp"
|
|
#include "ck_tile/core/arch/arch.hpp"
|
|
#include "ck_tile/core/arch/barrier.hpp"
|
|
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
|
#include "ck_tile/core/arch/inst_prefetch.hpp"
|
|
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
|
#include "ck_tile/core/arch/mma/mfma/mfma.hpp"
|
|
#include "ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp"
|
|
#include "ck_tile/core/arch/mma/mfma/mfma_selector.hpp"
|
|
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
|
|
#include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp"
|
|
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
|
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
|
|
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
|
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
|
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
|
|
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
|
|
#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp"
|
|
#include "ck_tile/core/arch/mma/scale/mfma/selector.hpp"
|
|
#include "ck_tile/core/arch/mma/scale/scale.hpp"
|
|
#include "ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp"
|
|
#include "ck_tile/core/arch/mma/scale/scale_selector.hpp"
|
|
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
|
|
#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/sparse.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp"
|
|
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"
|
|
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
|
|
#include "ck_tile/core/arch/mma/wmma/wmma.hpp"
|
|
#include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp"
|
|
#include "ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp"
|
|
#include "ck_tile/core/arch/mma/wmma/wmma_selector.hpp"
|
|
#include "ck_tile/core/arch/mma/wmma/wmma_traits.hpp"
|
|
#include "ck_tile/core/arch/mma/wmma/wmma_transforms.hpp"
|
|
#include "ck_tile/core/arch/utility.hpp"
|
|
#include "ck_tile/core/arch/workgroup_barrier.hpp"
|
|
#include "ck_tile/core/config.hpp"
|
|
#include "ck_tile/core/container/array.hpp"
|
|
#include "ck_tile/core/container/container_helper.hpp"
|
|
#include "ck_tile/core/container/map.hpp"
|
|
#include "ck_tile/core/container/meta_data_buffer.hpp"
|
|
#include "ck_tile/core/container/multi_index.hpp"
|
|
#include "ck_tile/core/container/sequence.hpp"
|
|
#include "ck_tile/core/container/span.hpp"
|
|
#include "ck_tile/core/container/static_array.hpp"
|
|
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
|
#include "ck_tile/core/container/thread_buffer.hpp"
|
|
#include "ck_tile/core/container/tuple.hpp"
|
|
#include "ck_tile/core/numeric/bfloat16.hpp"
|
|
#include "ck_tile/core/numeric/e4m3.hpp"
|
|
#include "ck_tile/core/numeric/e5m3.hpp"
|
|
#include "ck_tile/core/numeric/e8m0.hpp"
|
|
#include "ck_tile/core/numeric/ext_vector_base.hpp"
|
|
#include "ck_tile/core/numeric/float8.hpp"
|
|
#include "ck_tile/core/numeric/float8_ext.hpp"
|
|
#include "ck_tile/core/numeric/half.hpp"
|
|
#include "ck_tile/core/numeric/int8.hpp"
|
|
#include "ck_tile/core/numeric/integer.hpp"
|
|
#include "ck_tile/core/numeric/integral_constant.hpp"
|
|
#include "ck_tile/core/numeric/math.hpp"
|
|
#include "ck_tile/core/numeric/math_v2.hpp"
|
|
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
|
#include "ck_tile/core/numeric/mxfp_scale.hpp"
|
|
#include "ck_tile/core/numeric/null_type.hpp"
|
|
#include "ck_tile/core/numeric/numeric.hpp"
|
|
#include "ck_tile/core/numeric/pk_f6.hpp"
|
|
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
|
#include "ck_tile/core/numeric/pk_int4.hpp"
|
|
#include "ck_tile/core/numeric/scale_util.hpp"
|
|
#include "ck_tile/core/numeric/type_convert.hpp"
|
|
#include "ck_tile/core/numeric/vector_type.hpp"
|
|
#include "ck_tile/core/tensor/buffer_view.hpp"
|
|
#include "ck_tile/core/tensor/load_tile.hpp"
|
|
#include "ck_tile/core/tensor/load_tile_transpose.hpp"
|
|
#include "ck_tile/core/tensor/null_tensor.hpp"
|
|
#include "ck_tile/core/tensor/null_tile_window.hpp"
|
|
#include "ck_tile/core/tensor/shuffle_tile.hpp"
|
|
#include "ck_tile/core/tensor/slice_tile.hpp"
|
|
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
|
#include "ck_tile/core/tensor/store_tile.hpp"
|
|
#include "ck_tile/core/tensor/sweep_tile.hpp"
|
|
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
|
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
|
|
#include "ck_tile/core/tensor/tensor_coordinate.hpp"
|
|
#include "ck_tile/core/tensor/tensor_descriptor.hpp"
|
|
#include "ck_tile/core/tensor/tensor_view.hpp"
|
|
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
|
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
|
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
|
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
|
|
#include "ck_tile/core/tensor/tile_window.hpp"
|
|
#include "ck_tile/core/tensor/tile_window_base.hpp"
|
|
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
|
#include "ck_tile/core/tensor/tile_window_utils.hpp"
|
|
#include "ck_tile/core/tensor/transpose_tile.hpp"
|
|
#include "ck_tile/core/tensor/update_tile.hpp"
|
|
#include "ck_tile/core/utility/bit_cast.hpp"
|
|
#include "ck_tile/core/utility/data_cache_prefetch.hpp"
|
|
#include "ck_tile/core/utility/debug.hpp"
|
|
#include "ck_tile/core/utility/env.hpp"
|
|
#include "ck_tile/core/utility/functional.hpp"
|
|
#include "ck_tile/core/utility/functional_with_tuple.hpp"
|
|
#include "ck_tile/core/utility/gemm_validation.hpp"
|
|
#include "ck_tile/core/utility/ignore.hpp"
|
|
#include "ck_tile/core/utility/literals.hpp"
|
|
#include "ck_tile/core/utility/magic_div.hpp"
|
|
#include "ck_tile/core/utility/mixed_prec_compute_type.hpp"
|
|
#include "ck_tile/core/utility/persistent_async_input_scheduler.hpp"
|
|
#include "ck_tile/core/utility/philox_rand.hpp"
|
|
#include "ck_tile/core/utility/print.hpp"
|
|
#include "ck_tile/core/utility/random.hpp"
|
|
#include "ck_tile/core/utility/reduce_operator.hpp"
|
|
#include "ck_tile/core/utility/reduce_operator_accumulate.hpp"
|
|
#include "ck_tile/core/utility/static_counter.hpp"
|
|
#include "ck_tile/core/utility/to_sequence.hpp"
|
|
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
|
#include "ck_tile/core/utility/type_traits.hpp"
|
|
#include "ck_tile/core/utility/unary_element_function.hpp"
|