Improve XDL to WMMA porting for grouped conv fwd (#3456)

Refactors the way the number of XDL (matrix multiply-accumulate) instructions per wave is calculated and used in the grouped convolution forward implementations, especially to better support WMMA (Wave Matrix Multiply-Accumulate) instructions and 16x16 tiles. 
The changes use MXdlPerWave instead of NXdlPerWave to increase number of waves per M dim.
This commit is contained in:
Bartłomiej Kocot
2025-12-19 23:58:51 +01:00
committed by GitHub
parent 2d9c962e2c
commit cbc8335964
13 changed files with 226 additions and 133 deletions

View File

@@ -33,7 +33,7 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
.with_thread_block(FwdThreadBlock_64_64x32x32)
.with_gemm_config(FwdGemmParams_Xdl_2x2_per_wave)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(FwdTransfer_4x16x1)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);

View File

@@ -28,7 +28,7 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_128x128x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)

View File

@@ -111,8 +111,8 @@ struct DefaultAlgorithm
.bk1 = 8,
.m_per_xdl = 16,
.n_per_xdl = 16,
.m_xdl_per_wave = 4,
.n_xdl_per_wave = 4};
.m_xdl_per_wave = 8,
.n_xdl_per_wave = 8};
ckb::test::TransferABC transfer{
.a =
@@ -188,7 +188,7 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
" ├─ Pipeline scheduler: INTRAWAVE\n"
" ├─ Warp Gemm parameters: \n"
" │ ├─ subtile size: 16×16\n"
" │ └─ Number of warp gemm iterations: 4×4\n"
" │ └─ Number of warp gemm iterations: 8×8\n"
" └─ Memory access:\n"
" ├─ A Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"

View File

@@ -68,7 +68,7 @@ constexpr TransferABC FwdTransfer_4x64x1{
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.scalar_per_vector = 4},
},
};