CK Tile GEMM CICD fixed & register block method refactor (#1776)

* refactor the block_gemm_areg_breg_creg_v1 and add the v2 policy with 2x2 warp gemm

* Finished the 2x2 warp gemm policy and the block selection mechanism

* Clang format

* address poyen's comment

* Address feedbacks

* Fixed the compilation issue

* Change the function name
This commit is contained in:
Thomas Ning
2025-01-12 21:10:44 -08:00
committed by GitHub
parent 0b8f117f1a
commit 5d671a5fc4
8 changed files with 109 additions and 107 deletions

View File

@@ -104,9 +104,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;

View File

@@ -23,6 +23,8 @@ struct GemmPipelineAGmemBGmemCRegV1
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
@@ -126,7 +128,7 @@ struct GemmPipelineAGmemBGmemCRegV1
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Block GEMM
auto block_gemm = Policy::template GetBlockGemm<Problem>();
auto block_gemm = BlockGemm();
// Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};

View File

@@ -12,8 +12,11 @@ namespace ck_tile {
// Default policy class should not be templated, put template on member functions instead
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr bool TransposeC = false;
static constexpr bool TransposeC = true;
#if 0
// 2d
@@ -491,10 +494,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr auto I2 = number<2>{};
using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;

View File

@@ -11,7 +11,6 @@ namespace ck_tile {
// UniversalGemm Policy
struct UniversalGemmPipelineAgBgCrPolicy
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};