diff --git a/example/ck_tile/20_grouped_convolution/gemm_configs.hpp b/example/ck_tile/20_grouped_convolution/conv_configs.hpp similarity index 85% rename from example/ck_tile/20_grouped_convolution/gemm_configs.hpp rename to example/ck_tile/20_grouped_convolution/conv_configs.hpp index 77e1c3af1a..1be6080383 100644 --- a/example/ck_tile/20_grouped_convolution/gemm_configs.hpp +++ b/example/ck_tile/20_grouped_convolution/conv_configs.hpp @@ -17,7 +17,7 @@ #define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V5 4 -struct GemmConfigBase +struct ConvConfigBase { static constexpr bool kPadM = true; static constexpr bool kPadN = true; @@ -29,6 +29,10 @@ struct GemmConfigBase static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; + static constexpr ck_tile::index_t VectorSizeA = 4; + static constexpr ck_tile::index_t VectorSizeB = 8; + static constexpr ck_tile::index_t VectorSizeC = 8; + static constexpr int kBlockPerCu = 1; static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; @@ -37,10 +41,12 @@ struct GemmConfigBase static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; static constexpr bool TiledMMAPermuteN = false; + + static constexpr ck_tile::index_t NumGroupsToMerge = 1; }; template -struct GemmConfigMemoryInterwave : public GemmConfigBase +struct ConvConfigMemoryInterwave : public ConvConfigBase { // Memory friendly for Interwave scheduler static constexpr ck_tile::index_t M_Tile = 128; @@ -61,7 +67,7 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase }; template -struct GemmConfigMemoryIntrawave : public GemmConfigBase +struct ConvConfigMemoryIntrawave : public ConvConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 32; @@ -80,7 +86,7 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase }; template -struct GemmConfigComputeV3 : public GemmConfigBase +struct ConvConfigComputeV3 : public ConvConfigBase { // Compute V3 only support Intrawave scheduler static constexpr ck_tile::index_t M_Tile = 16; @@ -100,7 +106,7 @@ struct GemmConfigComputeV3 : public GemmConfigBase }; template -struct GemmConfigComputeV3_1 : public GemmConfigBase +struct ConvConfigComputeV3_1 : public ConvConfigBase { static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 256; @@ -119,7 +125,7 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase }; template -struct GemmConfigComputeV3_2 : public GemmConfigBase +struct ConvConfigComputeV3_2 : public ConvConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -140,7 +146,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase }; template -struct GemmConfigComputeV3_WMMA : public GemmConfigBase +struct ConvConfigComputeV3_WMMA : public ConvConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -161,7 +167,7 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase }; template -struct GemmConfigComputeV4 : public GemmConfigBase +struct ConvConfigComputeV4 : public ConvConfigBase { // Compute V4 only support Intrawave scheduler // Using the ping pong reader in the lds level @@ -182,7 +188,7 @@ struct GemmConfigComputeV4 : public GemmConfigBase }; template -struct GemmConfigComputeV4_1 : public GemmConfigBase +struct ConvConfigComputeV4_1 : public ConvConfigBase { static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 256; @@ -201,7 +207,7 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase }; template -struct GemmConfigComputeV5 : public GemmConfigBase +struct ConvConfigComputeV5 : public ConvConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -220,6 +226,31 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; }; +template +struct ConvConfigComputeV3_merged_groups : public ConvConfigBase +{ + static constexpr ck_tile::index_t VectorSizeA = 4; + static constexpr ck_tile::index_t VectorSizeB = 8; + static constexpr ck_tile::index_t VectorSizeC = 8; + + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 32; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr ck_tile::index_t NumGroupsToMerge = 2; +}; + template struct ConvTypeConfig; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp index 6f3bedc32a..ad593b1418 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp @@ -51,8 +51,8 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[]) int main(int argc, char* argv[]) { #if CK_TILE_USE_WMMA - return !run_grouped_conv_bwd_data_example(argc, argv); + return !run_grouped_conv_bwd_data_example(argc, argv); #else - return !run_grouped_conv_bwd_data_example(argc, argv); + return !run_grouped_conv_bwd_data_example(argc, argv); #endif } diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp index 159d19fdcd..695adf01bc 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp @@ -14,7 +14,7 @@ #include "grouped_convolution_backward_weight_invoker.hpp" #include "run_grouped_convolution_bwd_weight_example.inc" -template