Change the example to optimally match depthwise convolution with merged groups.

This commit is contained in:
Ville Pietilä
2025-09-15 12:56:03 +00:00
parent ff9732b937
commit e21ce62e53

View File

@@ -29,16 +29,19 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
{
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t M_Tile = 64;
constexpr ck_tile::index_t N_Tile = 64;
// Block tile: <MPerBlock, NPerBlock, KPerBlock>
// Note that we must satisfy
// - MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock
constexpr ck_tile::index_t M_Tile = 8; //64
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t M_Warp_Tile = 4; // 32
constexpr ck_tile::index_t N_Warp_Tile = 64; // 32
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr ck_tile::index_t VectorSizeA = 8;