mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Ck tile gemm example (#1488)
* Checkpoint: Finished with the tile example & kernel verification, working on the different matrix layout * Finished the Matrix Layout feature set up. Note: Need to modify the inner block to solve the shuffle problem in the future. * Fix: Clang Format, API fixed from fmha * fix with better naming convention * revert back the pipeline code of fmha * Fixed: Addressed the comments and merge the GEMM shape of GEMM Operator and FMHA Operator to one. * clang format with the reference_gemm file * convert the clang format with the remod.py * Changed the format and variable name of the kernel gemm_shape and partitioner --------- Co-authored-by: thomasning <thomasning@banff-cyxtera-s70-4.ctr.dcgpu>
This commit is contained in:
@@ -25,14 +25,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem =
|
||||
BlockGemmPipelineProblem<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>>;
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
|
||||
Problem::BlockFmhaShape::BlockTile::kN0,
|
||||
Problem::BlockFmhaShape::BlockTile::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
typename Problem::QDataType,
|
||||
@@ -57,14 +58,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem =
|
||||
BlockGemmPipelineProblem<typename Problem::GemmDataType,
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kVHeaddim,
|
||||
Problem::BlockFmhaShape::kK1>>;
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::GemmDataType,
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kN0,
|
||||
Problem::BlockFmhaShape::BlockTile::kVHeaddim,
|
||||
Problem::BlockFmhaShape::BlockTile::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
@@ -88,14 +90,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem =
|
||||
BlockGemmPipelineProblem<typename Problem::OGradDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK2>>;
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
|
||||
Problem::BlockFmhaShape::BlockTile::kN0,
|
||||
Problem::BlockFmhaShape::BlockTile::kK2>,
|
||||
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
typename Problem::OGradDataType,
|
||||
@@ -120,14 +123,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem =
|
||||
BlockGemmPipelineProblem<typename Problem::GemmDataType,
|
||||
typename Problem::QDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::kK3>>;
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::GemmDataType,
|
||||
typename Problem::QDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kN0,
|
||||
Problem::BlockFmhaShape::BlockTile::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::BlockTile::kK3>,
|
||||
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
@@ -151,14 +155,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem =
|
||||
BlockGemmPipelineProblem<typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::kK4>>;
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
|
||||
Problem::BlockFmhaShape::BlockTile::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::BlockTile::kK4>,
|
||||
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
|
||||
@@ -75,14 +75,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem =
|
||||
BlockGemmPipelineProblem<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>>;
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
@@ -198,14 +199,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem =
|
||||
BlockGemmPipelineProblem<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>>;
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
@@ -952,14 +954,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem =
|
||||
BlockGemmPipelineProblem<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN1,
|
||||
Problem::BlockFmhaShape::kK1>>;
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN1,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
|
||||
Reference in New Issue
Block a user