Improve XDL to WMMA porting for grouped conv fwd (#3456)

Refactors the way the number of XDL (matrix multiply-accumulate) instructions per wave is calculated and used in the grouped convolution forward implementations, especially to better support WMMA (Wave Matrix Multiply-Accumulate) instructions and 16x16 tiles. 
The changes use MXdlPerWave instead of NXdlPerWave to increase number of waves per M dim.
This commit is contained in:
Bartłomiej Kocot
2025-12-19 23:58:51 +01:00
committed by GitHub
parent 2d9c962e2c
commit cbc8335964
13 changed files with 226 additions and 133 deletions

View File

@@ -33,7 +33,7 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
.with_thread_block(FwdThreadBlock_64_64x32x32)
.with_gemm_config(FwdGemmParams_Xdl_2x2_per_wave)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(FwdTransfer_4x16x1)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);

View File

@@ -28,7 +28,7 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_128x128x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)