[rocm-libraries] ROCm/rocm-libraries#4582 (commit 990a00d)

[CK_Builder] added bwd data kernels to builder factory
 (#4582)

This PR adds backward-data (bwd data) WMMA and XDL kernels to the CK builder,
along with their instance traits and conv traits, and tests for all of the above.
This commit is contained in:
kabrahamAMD
2026-02-27 03:06:29 +00:00
committed by assistant-librarian[bot]
parent c8a8449eec
commit 5e06874aae
34 changed files with 2511 additions and 104 deletions

View File

@@ -282,38 +282,39 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
" ├─ Warp Gemm parameters: \n"
" │ ├─ subtile size: 16×16\n"
" │ └─ Number of warp gemm iterations: 8×8\n"
" ─ Memory access:\n"
" ├─ A Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
" │ ├─ Vector access (LDS write) instruction size: 2\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
" ├─ B Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
" │ ├─ Vector access (LDS write) instruction size: 2\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
" └─ C Tile transfer: \n"
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
" ─ Vector access (GMEM write) instruction size: 2\n"
" ─ Memory access:\n"
" ├─ A Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
" │ ├─ Vector access (LDS write) instruction size: 2\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
" ├─ B Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
" │ ├─ Vector access (LDS write) instruction size: 2\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
" └─ C Tile transfer: \n"
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
" ─ Vector access (GMEM write) instruction size: 2\n"
" ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n"
" ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector "
"parameter\n"
" ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n"
" ├─ Struct does not contain optional max_transpose_transfer_dst_scalar_per_vector "
"parameter\n"
" └─ Struct does not contain optional num_groups_to_merge parameter"));
}
// Test printing of optional parameters num_groups_to_merge,
// nax_transose_transfer_src_scalar_per_vector and max_transpose_dst_scalar_per_vector
// max_transpose_transfer_src_scalar_per_vector and max_transpose_transfer_dst_scalar_per_vector
TEST(ConvDescriptionTest, BwdWeightTwoStageWmmaV3DescriptionTest)
{
using Instance =
@@ -390,29 +391,29 @@ TEST(ConvDescriptionTest, BwdWeightTwoStageWmmaV3DescriptionTest)
" ├─ Warp Gemm parameters: \n"
" │ ├─ subtile size: 32×32\n"
" │ └─ Number of warp gemm iterations: 4×4\n"
" ─ Memory access:\n"
" ├─ A Tile transfer: \n"
" │ ├─ Tile dimensions: 2×128×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" ├─ B Tile transfer: \n"
" │ ├─ Tile dimensions: 2×128×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" └─ C Tile transfer: \n"
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
" ─ Vector access (GMEM write) instruction size: 8\n"
" ─ Memory access:\n"
" ├─ A Tile transfer: \n"
" │ ├─ Tile dimensions: 2×128×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" ├─ B Tile transfer: \n"
" │ ├─ Tile dimensions: 2×128×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" └─ C Tile transfer: \n"
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
" ─ Vector access (GMEM write) instruction size: 8\n"
" ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n"
" ├─ Max Transpose transfer scr scalar per vector: 1\n"
" ├─ Max Transpose dst scalar per vector: 1\n"
@@ -494,33 +495,34 @@ TEST(ConvDescriptionTest, BwdWeightWmmaCshuffleV3DescriptionTest)
" ├─ Warp Gemm parameters: \n"
" │ ├─ subtile size: 32×32\n"
" │ └─ Number of warp gemm iterations: 4×4\n"
" ─ Memory access:\n"
" ├─ A Tile transfer: \n"
" │ ├─ Tile dimensions: 2×128×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" ├─ B Tile transfer: \n"
" │ ├─ Tile dimensions: 2×128×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" └─ C Tile transfer: \n"
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
" ─ Vector access (GMEM write) instruction size: 8\n"
" ─ Memory access:\n"
" ├─ A Tile transfer: \n"
" │ ├─ Tile dimensions: 2×128×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" ├─ B Tile transfer: \n"
" │ ├─ Tile dimensions: 2×128×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" └─ C Tile transfer: \n"
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
" ─ Vector access (GMEM write) instruction size: 8\n"
" ├─ Num gemm k prefetch stage: 1\n"
" ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector "
"parameter\n"
" ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n"
" ├─ Struct does not contain optional max_transpose_transfer_dst_scalar_per_vector "
"parameter\n"
" └─ Struct does not contain optional num_groups_to_merge parameter"));
}