From 2117e76277b0cbd344e227bb368c61541262bb18 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Mon, 14 Oct 2024 13:59:26 +0800 Subject: [PATCH] decouple the calling from gemm_pipeline (#1571) * decouple the calling from gemm_pipeline * clang format [ROCm/composable_kernel commit: 35c1777d59d89ccab1b25391daf3836af5a75522] --- ...block_fmha_bwd_pipeline_default_policy.hpp | 118 +++++++----------- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 71 +++++------ 2 files changed, 74 insertions(+), 115 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index e1f05d39db..0afad0446c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -5,9 +5,8 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" -#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" @@ -27,20 +26,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { using GemmProblem = - GemmPipelineProblem, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>, - TileGemmTraits>; + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>>; using WarpGemm = WarpGemmMfmaDispatcher< typename Problem::QDataType, @@ -66,20 +60,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() { using GemmProblem = - GemmPipelineProblem, - typename Problem::BlockFmhaShape::Gemm1BlockWarps, - typename Problem::BlockFmhaShape::Gemm1WarpTile>, - TileGemmTraits>; + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>>; using WarpGemm = WarpGemmMfmaDispatcher, - typename Problem::BlockFmhaShape::Gemm2BlockWarps, - typename Problem::BlockFmhaShape::Gemm2WarpTile>, - TileGemmTraits>; + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm2BlockWarps, + typename Problem::BlockFmhaShape::Gemm2WarpTile>>; using WarpGemm = WarpGemmMfmaDispatcher< typename Problem::OGradDataType, @@ -143,20 +127,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() { using GemmProblem = - GemmPipelineProblem, - typename Problem::BlockFmhaShape::Gemm3BlockWarps, - typename Problem::BlockFmhaShape::Gemm3WarpTile>, - TileGemmTraits>; + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm3BlockWarps, + typename Problem::BlockFmhaShape::Gemm3WarpTile>>; using WarpGemm = WarpGemmMfmaDispatcher, - typename Problem::BlockFmhaShape::Gemm4BlockWarps, - typename Problem::BlockFmhaShape::Gemm4WarpTile>, - TileGemmTraits>; + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm4BlockWarps, + typename Problem::BlockFmhaShape::Gemm4WarpTile>>; using WarpGemm = WarpGemmMfmaDispatcher CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { using GemmProblem = - GemmPipelineProblem, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>, - TileGemmTraits>; + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>>; constexpr auto warp_gemm = []() { if constexpr(std::is_same_v && @@ -207,20 +202,15 @@ struct BlockFmhaPipelineQXCustomPolicy CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { using GemmProblem = - GemmPipelineProblem, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>, - TileGemmTraits>; + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>>; constexpr auto warp_gemm = []() { if constexpr(std::is_same_v && @@ -968,20 +958,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, - typename Problem::BlockFmhaShape::Gemm1BlockWarps, - typename Problem::BlockFmhaShape::Gemm1WarpTile>, - TileGemmTraits>; + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>>; auto warp_gemm = [&]() { if constexpr(std::is_same_v &&