From e21ce62e53ff6560add97929bf3b2dfb73f2f5cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Mon, 15 Sep 2025 12:56:03 +0000 Subject: [PATCH] Change example to match optimally depthwise convolution with merged groups. --- .../grouped_convolution_backward_weight.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp index 3d0d504427..5775acaf0d 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp @@ -29,16 +29,19 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, { constexpr int kBlockPerCu = 1; - constexpr ck_tile::index_t M_Tile = 64; - constexpr ck_tile::index_t N_Tile = 64; + // Block tile: + // Note that we must satisfy + // - MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock + constexpr ck_tile::index_t M_Tile = 8; //64 + constexpr ck_tile::index_t N_Tile = 128; constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t M_Warp_Tile = 4; // 32 + constexpr ck_tile::index_t N_Warp_Tile = 64; // 32 constexpr ck_tile::index_t K_Warp_Tile = 16; constexpr ck_tile::index_t VectorSizeA = 8;