diff --git a/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp index 1fffea1816..d0be065fc9 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp @@ -4,15 +4,10 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" namespace ck_tile { -enum struct GemmLoopOrder -{ - KMN = 0, - MNK = 1, -}; - // Problem Description for BlockGemm template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1, + GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN> struct GemmPipelineProblemBase { using Traits = remove_cvref_t; @@ -45,9 +46,10 @@ struct GemmPipelineProblemBase static constexpr bool kPadN = Traits::kPadN; static constexpr bool kPadK = Traits::kPadK; - static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; - static constexpr auto Scheduler = GemmPipelineScheduler::Default; - static constexpr index_t VectorLoadSize = Traits::_VectorSize; + static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; + static constexpr auto Scheduler = GemmPipelineScheduler::Default; + static constexpr index_t VectorLoadSize = Traits::_VectorSize; + static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_; // In the base situation, the Preshuffle setting should be false. static constexpr bool Preshuffle = false; @@ -167,10 +169,11 @@ template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1, + GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN> using GemmPipelineProblem = GemmPipelineProblemBase; + VectorSizeB_, + BlockGemmLoopOrder_>; template + GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave, + bool HasHotLoop_ = true, + TailNumber TailNum_ = TailNumber::Full, + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1, + GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN> struct UniversalGemmPipelineProblem { using Traits = remove_cvref_t; @@ -224,8 +229,9 @@ struct UniversalGemmPipelineProblem static constexpr auto Scheduler = Scheduler_; static constexpr bool Preshuffle = Traits::Preshuffle; - static constexpr index_t VectorSizeA = VectorSizeA_; - static constexpr index_t VectorSizeB = VectorSizeB_; + static constexpr index_t VectorSizeA = VectorSizeA_; + static constexpr index_t VectorSizeB = VectorSizeB_; + static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_; static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto TailNum = TailNum_;