Ck tile gemm example (#1488)

* Checkpoint: Finished with the tile example & kernel verification, working on the different matrix layout

* Finished the Matrix Layout feature set up. Note: Need to modify the inner block to solve the shuffle problem in the future.

* Fix: Clang Format, API fixed from fmha

* fix with better naming convention

* revert back the pipeline code of fmha

* Fixed: Addressed the comments and merge the GEMM shape of GEMM Operator and FMHA Operator to one.

* clang format with the reference_gemm file

* convert the clang format with the remod.py

* Changed the format and variable name of the kernel gemm_shape and partitioner

---------

Co-authored-by: thomasning <thomasning@banff-cyxtera-s70-4.ctr.dcgpu>
This commit is contained in:
Thomas Ning
2024-09-07 01:23:32 -07:00
committed by GitHub
parent 8378855361
commit caacd38830
18 changed files with 758 additions and 92 deletions

View File

@@ -4,7 +4,8 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
namespace ck_tile {
@@ -27,9 +28,9 @@ struct BlockGemmARegBGmemCRegV1
static constexpr index_t kBlockSize = Problem::kBlockSize;
// use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1<
using BlockGemmARegBGmemCRegImpl = BlockGemmARegBGmemCRegV1<
BlockGemmProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>,
BlockGemmARegBSmemCRegV1DefaultPolicy>;
BlockGemmARegBGmemCRegV1DefaultPolicy>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
{
@@ -82,7 +83,7 @@ struct BlockGemmARegBGmemCRegV1
block_sync_lds();
// block GEMM
BlockGemmARegBSmemCRegImpl{}(c_block_tensor, a_block_tensor, b_block_smem_window);
BlockGemmARegBGmemCRegImpl{}(c_block_tensor, a_block_tensor, b_block_smem_window);
}
// C = A * B
@@ -128,7 +129,7 @@ struct BlockGemmARegBGmemCRegV1
block_sync_lds();
// block GEMM
return BlockGemmARegBSmemCRegImpl{}(a_block_tensor, b_block_smem_window);
return BlockGemmARegBGmemCRegImpl{}(a_block_tensor, b_block_smem_window);
}
};

View File

@@ -49,6 +49,10 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
}
else
{
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
}
}
};