From 976e32ccfa5b468471d7e2751ba12fa855b688cb Mon Sep 17 00:00:00 2001 From: Bartlomiej Kocot Date: Tue, 13 May 2025 14:38:36 +0000 Subject: [PATCH] custom vector size --- .../grouped_convolution_forward.cpp | 26 ++++++++++--- .../ops/epilogue/cshuffle_epilogue.hpp | 14 ++++++- ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 4 +- .../gemm/pipeline/gemm_pipeline_problem.hpp | 38 ++++++++++++++++--- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 18 +++++---- 5 files changed, 77 insertions(+), 23 deletions(-) diff --git a/example/ck_tile/37_grouped_convolution/grouped_convolution_forward.cpp b/example/ck_tile/37_grouped_convolution/grouped_convolution_forward.cpp index 0cc5f1cc83..a2c5da56b3 100644 --- a/example/ck_tile/37_grouped_convolution/grouped_convolution_forward.cpp +++ b/example/ck_tile/37_grouped_convolution/grouped_convolution_forward.cpp @@ -38,6 +38,10 @@ float grouped_conv_fwd_calc(const ck_tile::GroupedConvHostArgs& args, constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 16; + constexpr ck_tile::index_t VectorSizeA = 8; + constexpr ck_tile::index_t VectorSizeB = 8; + constexpr ck_tile::index_t VectorSizeC = 8; + // Implicit GEMM Traits using CodegenShape = ck_tile::TileGemmShape, @@ -47,9 +51,16 @@ float grouped_conv_fwd_calc(const ck_tile::GroupedConvHostArgs& args, using TilePartitioner = ck_tile::GemmTile1DPartitioner; using CodegenTraits = ck_tile::GroupedConvImplicitGemmTraits; - using CodegenPipelineProblem = ck_tile:: - GemmPipelineProblem; - using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; + using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; @@ -69,7 +80,9 @@ float grouped_conv_fwd_calc(const ck_tile::GroupedConvHostArgs& args, N_Warp_Tile, K_Warp_Tile, CodegenPipelineProblem::TransposeC, - memory_operation>>; + 
memory_operation, + true, + VectorSizeC>>; constexpr auto ConvSpec = ck_tile::ConvolutionForwardSpecialization::Default; @@ -99,7 +112,10 @@ float grouped_conv_fwd_calc(const ck_tile::GroupedConvHostArgs& args, << "pipeline: " << CodegenPipeline::GetName() << '\n' << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + << '\n' + << "Vector size A: " << CodegenPipeline::GetVectorSizeA() + << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; } float ave_time = ck_tile::launch_kernel( diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 9b8dde1905..36c84bad87 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -23,7 +23,9 @@ template + memory_operation_enum MemoryOperation_, + bool FixedVectorSize_ = false, + index_t VectorSizeC_ = 1> struct CShuffleEpilogueProblem { using ADataType = remove_cvref_t; @@ -41,6 +43,8 @@ struct CShuffleEpilogueProblem static constexpr index_t kKPerXdl = kKPerXdl_; static constexpr index_t isCTransposed = isCTransposed_; static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; + static constexpr bool FixedVectorSize = FixedVectorSize_; + static constexpr index_t VectorSizeC = VectorSizeC_; }; template @@ -65,6 +69,8 @@ struct CShuffleEpilogue static constexpr index_t kNPerXdl = Problem::kNPerXdl; static constexpr index_t kKPerXdl = Problem::kKPerXdl; static constexpr index_t isCTransposed = Problem::isCTransposed; + static constexpr bool FixedVectorSize = Problem::FixedVectorSize; + static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr index_t kMPerIteration = kMPerXdl * kMWave; static constexpr index_t kNPerIteration = kNPerXdl * kNWave; @@ -91,8 +97,12 @@ struct 
CShuffleEpilogue */ CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() { + if constexpr(FixedVectorSize) + { + return VectorSizeC; + } constexpr index_t MaxVectorStoreSize = 16; - return MaxVectorStoreSize / sizeof(ODataType); + return static_cast(MaxVectorStoreSize / sizeof(ODataType)); } template diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index 6bb14af9e6..0f7f6369f0 100755 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -121,7 +121,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy if constexpr(std::is_same_v) { - constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M1 = Problem::VectorSizeA; constexpr index_t M0 = MPerBlock / M1; constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; static_assert(total_pixels % M1 == 0); @@ -211,7 +211,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy if constexpr(std::is_same_v) { - constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t N1 = Problem::VectorSizeB; constexpr index_t N0 = NPerBlock / N1; constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; static_assert(total_pixels % N1 == 0); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 0b38e7789e..c1e15f892b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -14,7 +14,10 @@ template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1> struct GemmPipelineProblemBase { using Traits = remove_cvref_t; @@ -24,6 +27,8 @@ struct 
GemmPipelineProblemBase using CDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; + static constexpr bool FixedVectorSize = FixedVectorSize_; + using BlockGemmShape = remove_cvref_t; using ALayout = remove_cvref_t; @@ -114,7 +119,11 @@ struct GemmPipelineProblemBase } static constexpr index_t VectorSizeA = []() { - if constexpr(std::is_same_v) + if constexpr(FixedVectorSize) + { + return VectorSizeA_; + } + else if constexpr(std::is_same_v) { return kPadK ? 1 : GetAlignmentA(); } @@ -125,7 +134,11 @@ struct GemmPipelineProblemBase }(); static constexpr index_t VectorSizeB = []() { - if constexpr(std::is_same_v) + if constexpr(FixedVectorSize) + { + return VectorSizeB_; + } + else if constexpr(std::is_same_v) { return kPadN ? 1 : GetAlignmentB(); } @@ -152,13 +165,19 @@ template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1> using GemmPipelineProblem = GemmPipelineProblemBase; + ComputeDataType_, + FixedVectorSize_, + VectorSizeA_, + VectorSizeB_>; template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1> struct UniversalGemmPipelineProblem { using Traits = remove_cvref_t; @@ -178,6 +200,10 @@ struct UniversalGemmPipelineProblem using CDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; + static constexpr bool FixedVectorSize = FixedVectorSize_; + static constexpr index_t VectorSizeA = VectorSizeA_; /* element count, hence index_t rather than bool */ + static constexpr index_t VectorSizeB = VectorSizeB_; + using BlockGemmShape = remove_cvref_t; using ALayout = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 6890cf2f64..c189230643 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ 
b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -426,10 +426,11 @@ struct UniversalGemmBasePolicy { using ALayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeA(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA(); // Tile: MPerBlock X KPerBlock if constexpr(std::is_same_v) @@ -458,10 +459,11 @@ struct UniversalGemmBasePolicy { using BLayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeB(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); // Tile: KPerBlock X NPerBlock if constexpr(std::is_same_v)