mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
Fix WMMA bwd weight tests.
This commit is contained in:
@@ -164,15 +164,22 @@ add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
)
|
||||
target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility)
|
||||
|
||||
add_ck_builder_test(test_ckb_build_bwd_weight_instances
|
||||
set(BWD_WEIGHT_TESTS
|
||||
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_dl.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp
|
||||
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
|
||||
)
|
||||
|
||||
if (CK_USE_WMMA)
|
||||
list(APPEND BWD_WEIGHT_TESTS
|
||||
conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
add_ck_builder_test(test_ckb_build_bwd_weight_instances ${BWD_WEIGHT_TESTS})
|
||||
target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility)
|
||||
|
||||
add_ck_builder_test(test_ckb_build_bwd_data_instances
|
||||
|
||||
@@ -22,10 +22,10 @@ constexpr auto SIGNATURE =
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::GemmParams_Wmma_2x1_per_wave)
|
||||
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave)
|
||||
.with_transpose_params(4,4);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
@@ -40,5 +40,5 @@ TEST(BwdWeight_1DBf16_Wmma_CShuffle_V3, Create)
|
||||
"NGCW,GKXC,NGKW",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"Intrawave",
|
||||
"v2"});
|
||||
"v1"});
|
||||
}
|
||||
|
||||
@@ -281,6 +281,12 @@ constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{.k1 = 8,
|
||||
.m_wmma_per_wave = 2,
|
||||
.n_wmma_per_wave = 1};
|
||||
|
||||
constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{.k1 = 8,
|
||||
.m_per_wmma = 16,
|
||||
.n_per_wmma = 16,
|
||||
.m_wmma_per_wave = 2,
|
||||
.n_wmma_per_wave = 1};
|
||||
|
||||
constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user