mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Change example to match optimally depthwise convolution with merged groups.
This commit is contained in:
@@ -29,16 +29,19 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 64;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
// Block tile: <MPerBlock, NPerBlock, KPerBlock>
|
||||
// Note that we must satisfy
|
||||
// - MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock
|
||||
constexpr ck_tile::index_t M_Tile = 8; //64
|
||||
constexpr ck_tile::index_t N_Tile = 128;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 4; // 32
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 64; // 32
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
|
||||
Reference in New Issue
Block a user