Ck tile gemm example (#1488)

* Checkpoint: Finished with the tile example & kernel verification, working on the different matrix layout

* Finished the Matrix Layout feature set up. Note: Need to modify the inner block to solve the shuffle problem in the future.

* Fix: Clang Format, API fixed from fmha

* fix with better naming convention

* revert back the pipeline code of fmha

* Fixed: Addressed the comments and merge the GEMM shape of GEMM Operator and FMHA Operator to one.

* clang format with the reference_gemm file

* convert the clang format with the remod.py

* Changed the format and variable name of the kernel gemm_shape and partitioner

---------

Co-authored-by: thomasning <thomasning@banff-cyxtera-s70-4.ctr.dcgpu>
This commit is contained in:
Thomas Ning
2024-09-07 01:23:32 -07:00
committed by GitHub
parent 8378855361
commit caacd38830
18 changed files with 758 additions and 92 deletions

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
namespace ck_tile {
@@ -18,12 +19,16 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr index_t AlignmentA = Problem::AlignmentA;
static constexpr index_t AlignmentB = Problem::AlignmentB;
static constexpr index_t AlignmentC = Problem::AlignmentC;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
{
return ck_tile::integer_divide_ceil(
@@ -35,6 +40,11 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
@@ -140,8 +150,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
}
index_t iCounter = num_loop - 1;
do
while(iCounter > 0)
{
// global read i + 1
a_block_tile = load_tile(a_copy_dram_window);
@@ -167,8 +176,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
store_tile(b_copy_lds_window, b_block_tile_tmp);
iCounter--;
} while(iCounter > 0);
}
// tail
{

View File

@@ -91,6 +91,33 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
return b_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
{
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB()
{
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
index_t smem_size = 0;
smem_size += smem_size_a + smem_size_b;
return smem_size;
}
#elif 1
// fake XOR
template <typename Problem>
@@ -168,7 +195,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
@@ -177,7 +204,9 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
#if 1 // coalesce reading for each blocks
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M1 = KernelBlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution(
@@ -188,7 +217,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
sequence<1, 2>,
sequence<0, 1>>{});
#else // coalesce reading for each warps
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M0 = KernelBlockSize / get_warp_size();
constexpr index_t M1 = kMPerBlock / (M2 * M0);
return make_static_tile_distribution(
@@ -206,7 +235,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
@@ -215,7 +244,9 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
#if 1 // coalesce reading for each blocks
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N1 = KernelBlockSize / get_warp_size();
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
@@ -226,7 +257,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
sequence<1, 2>,
sequence<0, 1>>{});
#else // coalesce reading for each warps
constexpr index_t N0 = kBlockSize / get_warp_size();
constexpr index_t N0 = KernelBlockSize / get_warp_size();
constexpr index_t N1 = kNPerBlock / (N2 * N0);
return make_static_tile_distribution(

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
namespace ck_tile {
@@ -18,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV2
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;

View File

@@ -5,13 +5,17 @@
#include "ck_tile/core.hpp"
#define VectorLoadSize 16
namespace ck_tile {
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
index_t kBlockSize_,
typename BlockGemmShape_>
typename BlockGemmShape_,
bool kPadA_ = false,
bool kPadB_ = false,
bool kPadC_ = false>
struct BlockGemmPipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
@@ -19,7 +23,14 @@ struct BlockGemmPipelineProblem
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t KernelBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadA = kPadA_;
static constexpr bool kPadB = kPadB_;
static constexpr bool kPadC = kPadC_;
static constexpr index_t AlignmentA = kPadA ? VectorLoadSize / sizeof(ADataType) : 1;
static constexpr index_t AlignmentB = kPadB ? VectorLoadSize / sizeof(BDataType) : 1;
static constexpr index_t AlignmentC = kPadC ? VectorLoadSize / sizeof(CDataType) : 1;
};
} // namespace ck_tile

View File

@@ -7,12 +7,18 @@
namespace ck_tile {
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
template <typename BlockTile_, typename BlockWarps_, typename WarpTile_>
struct TileGemmShape
{
static constexpr index_t kM = kMPerTile;
static constexpr index_t kN = kNPerTile;
static constexpr index_t kK = kKPerTile;
using BlockTile = remove_cvref_t<BlockTile_>;
using BlockWarps = remove_cvref_t<BlockWarps_>;
using WarpTile = remove_cvref_t<WarpTile_>;
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t kM = BlockTile::at(number<0>{});
static constexpr index_t kN = BlockTile::at(number<1>{});
static constexpr index_t kK = BlockTile::at(number<2>{});
};
} // namespace ck_tile