From 89934275f48b3540df577e6187849fe935ff4b84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 2 Jan 2026 07:07:08 -0500 Subject: [PATCH] Fix WMMA bwd weight tests. --- experimental/builder/test/CMakeLists.txt | 11 +++++++++-- .../ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp | 6 +++--- .../builder/test/utils/ckb_conv_test_configs.hpp | 6 ++++++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index f72a72f5fb..3d6e448d84 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -164,15 +164,22 @@ add_ck_builder_test(test_ckb_build_fwd_instances ) target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility) -add_ck_builder_test(test_ckb_build_bwd_weight_instances +set(BWD_WEIGHT_TESTS conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp conv/ck/test_ckb_conv_bwd_weight_dl.cpp - conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp +) + +if (CK_USE_WMMA) + list(APPEND BWD_WEIGHT_TESTS + conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp ) +endif() + +add_ck_builder_test(test_ckb_build_bwd_weight_instances ${BWD_WEIGHT_TESTS}) target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility) add_ck_builder_test(test_ckb_build_bwd_data_instances diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp index 268bdee5be..4a1a60e852 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp @@ -22,10 +22,10 @@ constexpr auto SIGNATURE = constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3{} .with_thread_block(cku::ThreadBlock_64_32x32x32) - .with_gemm_config(cku::GemmParams_Wmma_2x1_per_wave) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) - .with_block_gemm(cku::BlockGemmDesc_v2_intrawave) + .with_block_gemm(cku::BlockGemmDesc_v1_intrawave) .with_transpose_params(4,4); using Builder = ckb::ConvBuilder; @@ -40,5 +40,5 @@ TEST(BwdWeight_1DBf16_Wmma_CShuffle_V3, Create) "NGCW,GKXC,NGKW", "PassThrough,PassThrough,PassThrough", "Intrawave", - "v2"}); + "v1"}); } diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 299aa56fac..e1f5a34e20 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -281,6 +281,12 @@ constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{.k1 = 8, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1}; +constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{.k1 = 8, + .m_per_wmma = 16, + .n_per_wmma = 16, + .m_wmma_per_wave = 2, + .n_wmma_per_wave = 1}; + constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}};