diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp index 3d0d504427..5775acaf0d 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp @@ -29,16 +29,19 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, { constexpr int kBlockPerCu = 1; - constexpr ck_tile::index_t M_Tile = 64; - constexpr ck_tile::index_t N_Tile = 64; + // Block tile: + // Note that we must satisfy + // - MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock + constexpr ck_tile::index_t M_Tile = 8; //64 + constexpr ck_tile::index_t N_Tile = 128; constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t M_Warp_Tile = 4; // 32 + constexpr ck_tile::index_t N_Warp_Tile = 64; // 32 constexpr ck_tile::index_t K_Warp_Tile = 16; constexpr ck_tile::index_t VectorSizeA = 8;