mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Improve XDL to WMMA porting for grouped conv fwd (#3456)
Refactors the way the number of XDL (matrix multiply-accumulate) instructions per wave is calculated and used in the grouped convolution forward implementations, especially to better support WMMA (Wave Matrix Multiply-Accumulate) instructions and 16x16 tiles. The changes use MXdlPerWave instead of NXdlPerWave to increase number of waves per M dim.
This commit is contained in:
@@ -33,7 +33,7 @@ TEST(FwdConvInstances,
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_64_64x32x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x2_per_wave)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(FwdTransfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
|
||||
|
||||
@@ -28,7 +28,7 @@ TEST(FwdConvInstances,
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
|
||||
Reference in New Issue
Block a user