mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[rocm-libraries] ROCm/rocm-libraries#4582 (commit 990a00d)
[CK_Builder] Added bwd data kernels to builder factory (#4582). This PR adds the bwd data WMMA and XDL kernels to the CK builder, together with their instance traits and conv traits, and adds tests for all of the above.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c8a8449eec
commit
5e06874aae
@@ -249,6 +249,26 @@ constexpr Transfer<> Transfer_4x32x1{
|
||||
},
|
||||
};
|
||||
|
||||
// GridwiseBwdDataXdlGemm preset "4x4 per wave":
//   ak1 = bk1 = 8, m_per_xdl = n_per_xdl = 32, m_xdl_per_wave = 4, n_xdl_per_wave = 4.
// NOTE(review): ak1/bk1 are presumably the K1 values for the A and B operands
// respectively -- confirm against the GridwiseBwdDataXdlGemm declaration.
constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_4x4_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
|
||||
|
||||
// GridwiseBwdDataXdlGemm preset "4x2 per wave":
//   ak1 = bk1 = 8, m_per_xdl = n_per_xdl = 32, m_xdl_per_wave = 4, n_xdl_per_wave = 2.
// NOTE(review): ak1/bk1 are presumably the K1 values for the A and B operands
// respectively -- confirm against the GridwiseBwdDataXdlGemm declaration.
constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_4x2_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}};
|
||||
|
||||
// GridwiseBwdDataXdlGemm preset "2x2 per wave":
//   ak1 = bk1 = 8, m_per_xdl = n_per_xdl = 32, m_xdl_per_wave = 2, n_xdl_per_wave = 2.
// NOTE(review): ak1/bk1 are presumably the K1 values for the A and B operands
// respectively -- confirm against the GridwiseBwdDataXdlGemm declaration.
constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_2x2_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}};
|
||||
|
||||
// GridwiseBwdDataXdlGemm preset "2x1 per wave":
//   ak1 = bk1 = 8, m_per_xdl = n_per_xdl = 32, m_xdl_per_wave = 2, n_xdl_per_wave = 1.
// NOTE(review): ak1/bk1 are presumably the K1 values for the A and B operands
// respectively -- confirm against the GridwiseBwdDataXdlGemm declaration.
constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_2x1_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}};
|
||||
|
||||
// GridwiseBwdXdlGemm preset "4x4 per wave":
//   k1 = 8, m_per_xdl = n_per_xdl = 32, m_xdl_per_wave = 4, n_xdl_per_wave = 4.
// This type takes a single k1 value rather than separate ak1/bk1 fields.
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{
|
||||
.k1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
|
||||
@@ -283,6 +303,13 @@ constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{
|
||||
// GridwiseWmmaGemm preset "16x16, 2x1 per wave":
//   k1 = 8, m_per_wmma = n_per_wmma = 16, m_wmma_per_wave = 2, n_wmma_per_wave = 1.
constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{
|
||||
.k1 = 8, .m_per_wmma = 16, .n_per_wmma = 16, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1};
|
||||
|
||||
// GridwiseWmmaGemmABK1 preset "16x16, 2x1 per wave":
//   ak1 = bk1 = 8, m_per_wmma = n_per_wmma = 16, m_wmma_per_wave = 2, n_wmma_per_wave = 1.
// Same tile values as GemmParams_Wmma_16x16_2x1_per_wave, but expressed with
// separate ak1/bk1 fields instead of a single k1.
// NOTE(review): ak1/bk1 are presumably the K1 values for the A and B operands
// respectively -- confirm against the GridwiseWmmaGemmABK1 declaration.
constexpr GridwiseWmmaGemmABK1 GemmParamsABK1_Wmma_16x16_2x1_per_wave{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_wmma = 16,
|
||||
.n_per_wmma = 16,
|
||||
.m_wmma_per_wave = 2,
|
||||
.n_wmma_per_wave = 1};
|
||||
|
||||
// ThreadBlock preset: block_size = 256 with tile_size m = 256, n = 256, k = 32.
// The name encodes these values as <block_size>_<m>x<n>x<k>.
constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user