[rocm-libraries] ROCm/rocm-libraries#4582 (commit 990a00d)

[CK_Builder] added bwd data kernels to builder factory
 (#4582)

This PR adds bwd data wmma and xdl kernels to the ck builder, their
instance and conv traits as well as tests for the above.
This commit is contained in:
kabrahamAMD
2026-02-27 03:06:29 +00:00
committed by assistant-librarian[bot]
parent c8a8449eec
commit 5e06874aae
34 changed files with 2511 additions and 104 deletions

View File

@@ -249,6 +249,26 @@ constexpr Transfer<> Transfer_4x32x1{
},
};
constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_4x4_per_wave{
.ak1 = 8,
.bk1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_4x2_per_wave{
.ak1 = 8,
.bk1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}};
constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_2x2_per_wave{
.ak1 = 8,
.bk1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}};
constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_2x1_per_wave{
.ak1 = 8,
.bk1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}};
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{
.k1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
@@ -283,6 +303,13 @@ constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{
constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{
.k1 = 8, .m_per_wmma = 16, .n_per_wmma = 16, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1};
constexpr GridwiseWmmaGemmABK1 GemmParamsABK1_Wmma_16x16_2x1_per_wave{.ak1 = 8,
.bk1 = 8,
.m_per_wmma = 16,
.n_per_wmma = 16,
.m_wmma_per_wave = 2,
.n_wmma_per_wave = 1};
constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};