From df8769d3c88d3e3c5b7344cff62f3bd3b77a2254 Mon Sep 17 00:00:00 2001 From: Dan Yao Date: Wed, 11 Sep 2024 01:45:32 +0800 Subject: [PATCH] [CK_TILE] FA bwd repair (#1502) * fix fa bwd * revert kernelBlockSize in gemm_kernel.hpp [ROCm/composable_kernel commit: d09572e8c278ed467433c4705e1cf664c301bf1e] --- ...block_fmha_bwd_pipeline_default_policy.hpp | 30 +++++++++---------- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 2 +- ...lock_gemm_pipeline_agmem_bgmem_creg_v1.hpp | 2 +- ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 12 ++++---- ...lock_gemm_pipeline_agmem_bgmem_creg_v2.hpp | 2 +- .../pipeline/block_gemm_pipeline_problem.hpp | 8 ++--- 6 files changed, 28 insertions(+), 28 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index e6a71f210e..9e1ab81125 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -29,9 +29,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename Problem::QDataType, typename Problem::KDataType, typename Problem::AccDataType, - TileGemmShape, + TileGemmShape, typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0WarpTile>>; @@ -62,9 +62,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename Problem::GemmDataType, typename Problem::OGradDataType, typename Problem::AccDataType, - TileGemmShape, + TileGemmShape, typename Problem::BlockFmhaShape::Gemm1BlockWarps, typename Problem::BlockFmhaShape::Gemm1WarpTile>>; @@ -94,9 +94,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename Problem::OGradDataType, typename Problem::VDataType, typename Problem::AccDataType, - TileGemmShape, + TileGemmShape, typename Problem::BlockFmhaShape::Gemm2BlockWarps, typename Problem::BlockFmhaShape::Gemm2WarpTile>>; @@ -127,9 +127,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename Problem::GemmDataType, typename Problem::QDataType, typename Problem::AccDataType, - TileGemmShape, + TileGemmShape, typename Problem::BlockFmhaShape::Gemm3BlockWarps, typename Problem::BlockFmhaShape::Gemm3WarpTile>>; @@ -159,9 +159,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename Problem::GemmDataType, typename Problem::KDataType, typename Problem::AccDataType, - TileGemmShape, + TileGemmShape, typename Problem::BlockFmhaShape::Gemm4BlockWarps, typename Problem::BlockFmhaShape::Gemm4WarpTile>>; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 8cdf9b1005..01d8f23288 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -25,7 +25,7 @@ struct GemmKernel using LayoutA = remove_cvref_t; using LayoutB = remove_cvref_t; using LayoutC = remove_cvref_t; - static constexpr index_t KernelBlockSize = GemmPipeline::KernelBlockSize; + static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp index a90178ddb1..0557143bc8 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -19,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1 using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - static constexpr index_t KernelBlockSize = Problem::KernelBlockSize; + static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = BlockGemmShape::kM; static constexpr index_t kNPerBlock = BlockGemmShape::kN; diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index c7f292d2b5..3048adad67 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -195,7 +195,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy { using ADataType = remove_cvref_t; - constexpr index_t KernelBlockSize = Problem::KernelBlockSize; + constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; @@ -204,7 +204,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy constexpr index_t K0 = kKPerBlock / K1; constexpr index_t M2 = get_warp_size() / K0; #if 1 // coalesce reading for each blocks - constexpr index_t M1 = KernelBlockSize / get_warp_size(); + constexpr index_t M1 = kBlockSize / get_warp_size(); static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); constexpr index_t M0 = kMPerBlock / (M2 * M1); @@ -217,7 +217,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy sequence<1, 2>, sequence<0, 1>>{}); #else // coalesce reading for each warps - constexpr index_t M0 = KernelBlockSize / get_warp_size(); + constexpr index_t M0 = kBlockSize / get_warp_size(); constexpr index_t M1 = kMPerBlock / (M2 * M0); return make_static_tile_distribution( @@ -235,7 +235,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy { using BDataType = remove_cvref_t; - constexpr index_t KernelBlockSize = Problem::KernelBlockSize; + constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; @@ -244,7 +244,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy constexpr index_t K0 = kKPerBlock / K1; constexpr index_t N2 = get_warp_size() / K0; #if 1 // coalesce reading for each blocks - constexpr index_t N1 = KernelBlockSize / get_warp_size(); + constexpr index_t N1 = kBlockSize / get_warp_size(); static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error."); static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error."); constexpr index_t N0 = kNPerBlock / (N2 * N1); @@ -257,7 +257,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy sequence<1, 2>, sequence<0, 1>>{}); #else // coalesce reading for each warps - constexpr index_t N0 = KernelBlockSize / get_warp_size(); + constexpr index_t N0 = kBlockSize / get_warp_size(); constexpr index_t N1 = kNPerBlock / (N2 * N0); return make_static_tile_distribution( diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp index deb9b07f16..ab5fe79114 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -19,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV2 using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - static constexpr index_t KernelBlockSize = Problem::KernelBlockSize; + static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = BlockGemmShape::kM; static constexpr index_t kNPerBlock = BlockGemmShape::kN; diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp index dda6022dc8..acb94f8a68 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp @@ -23,10 +23,10 @@ struct BlockGemmPipelineProblem using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - static constexpr index_t KernelBlockSize = BlockGemmShape::NumWarps * get_warp_size(); - static constexpr bool kPadA = kPadA_; - static constexpr bool kPadB = kPadB_; - static constexpr bool kPadC = kPadC_; + static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); + static constexpr bool kPadA = kPadA_; + static constexpr bool kPadB = kPadB_; + static constexpr bool kPadC = kPadC_; static constexpr index_t AlignmentA = kPadA ? VectorLoadSize / sizeof(ADataType) : 1; static constexpr index_t AlignmentB = kPadB ? VectorLoadSize / sizeof(BDataType) : 1;