Fix WMMA bwd weight tests.

This commit is contained in:
Ville Pietilä
2026-01-02 07:07:08 -05:00
parent 2e43e16e47
commit 89934275f4
3 changed files with 18 additions and 5 deletions

View File

@@ -164,15 +164,22 @@ add_ck_builder_test(test_ckb_build_fwd_instances
)
target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility)
add_ck_builder_test(test_ckb_build_bwd_weight_instances
set(BWD_WEIGHT_TESTS
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp
conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp
conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp
conv/ck/test_ckb_conv_bwd_weight_dl.cpp
conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
)
if (CK_USE_WMMA)
list(APPEND BWD_WEIGHT_TESTS
conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp
)
endif()
add_ck_builder_test(test_ckb_build_bwd_weight_instances ${BWD_WEIGHT_TESTS})
target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility)
add_ck_builder_test(test_ckb_build_bwd_data_instances

View File

@@ -22,10 +22,10 @@ constexpr auto SIGNATURE =
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3{}
.with_thread_block(cku::ThreadBlock_64_32x32x32)
.with_gemm_config(cku::GemmParams_Wmma_2x1_per_wave)
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave)
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave)
.with_transpose_params(4,4);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
@@ -40,5 +40,5 @@ TEST(BwdWeight_1DBf16_Wmma_CShuffle_V3, Create)
"NGCW,GKXC,NGKW",
"PassThrough,PassThrough,PassThrough",
"Intrawave",
"v2"});
"v1"});
}

View File

@@ -281,6 +281,12 @@ constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{.k1 = 8,
.m_wmma_per_wave = 2,
.n_wmma_per_wave = 1};
// WMMA GEMM parameter preset whose name spells out its values:
// a 16x16 per-WMMA tile (m_per_wmma x n_per_wmma) with 2 WMMA tiles
// along M and 1 along N per wave, and k1 = 8.
// NOTE(review): the exact meaning of k1 is not visible from this file
// section -- confirm against the GridwiseWmmaGemm declaration.
constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{.k1 = 8,
.m_per_wmma = 16,
.n_per_wmma = 16,
.m_wmma_per_wave = 2,
.n_wmma_per_wave = 1};
// Thread-block preset: block_size = 256 with a tile of m = 256, n = 256,
// k = 32. The identifier encodes these as <block_size>_<m>x<n>x<k>.
constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};