Ck tile gemm cshuffle & CK Tile GEMM restructure (#1535)

* ake the cshuffle compilable

* modify Mhe reference on gpu and cpu. Correaccess of cshuffle

* fix the cpu reference code

* Complete the in tile shuffle logic

* restructure the kernel template input

* change the naming pattern of ck_tile gemm pipeline

* Re-format files using remod.py

* Solve the fmha conflict with gemm

* Comment Addressed from Carlus

---------

Co-authored-by: Po Yen, Chen <PoYen.Chen@amd.com>
This commit is contained in:
Thomas Ning
2024-10-10 03:02:22 -07:00
committed by GitHub
parent 2e1165c1a7
commit 6f27bc9872
15 changed files with 447 additions and 142 deletions

View File

@@ -41,18 +41,39 @@ template <typename LayoutA,
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true;
constexpr bool kPadB = true;
constexpr bool kPadA = true;
constexpr bool kPadB = true;
constexpr bool kTilePermute = false;
constexpr int kBlockPerCu = 1;
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::ColumnMajor>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
kPadA,
kPadB,
kTilePermute,
kOutputRank,
1,
0,
TilePartitioner::kM,
TilePartitioner::kN>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel =
ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, LayoutA, LayoutB, LayoutC>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b,
@@ -255,15 +276,13 @@ int main(int argc, char* argv[])
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using CodegenPipelineProblem = ck_tile::BlockGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
kPadA,
kPadB,
kPadC>;
using CodegenGemmTraits = ck_tile::
TileGemmTraits<kPadA, kPadB, kPadC, matrix_a_layout, matrix_b_layout, matrix_c_layout>;
using CodegenGemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
invoke_gemm<ck_tile::half_t,
matrix_a_layout,
@@ -341,7 +360,13 @@ int main(int argc, char* argv[])
ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions);
ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes());
ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
matrix_a_layout,
matrix_b_layout,
matrix_c_layout>(
a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c);
c_buf.FromDevice(c_host_gpu_ref.data());