From 0bb1420fcf92087442ace07a7d2c8ab08668774d Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Tue, 24 Feb 2026 08:15:39 +0000 Subject: [PATCH] Remove some not very much required interfaces from pipeline problem --- .../pipeline/block_fmha_pipeline_problem.hpp | 2 +- ..._ks_vs_whole_k_prefetch_default_policy.hpp | 25 ++++++++++--------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 2c4385027d..52d8146432 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -39,7 +39,7 @@ CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize() return 1; } else - static_assert(false, "The data type is not supported!"); + return 1; }; template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() { + using VDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + // special consideration when shuffling is required before storing V to LDS if constexpr(!Problem::kUseTrLoad) { - using VDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - constexpr index_t kMaxVecLoad = Problem::GetVDramTileAccessMaxVectorSize(); + constexpr index_t kMaxVecLoad = detail:: + GetDramTileAccessMaxVectorSize(); constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); // try to avoid writing sub-dword to LDS due to poor performance @@ -176,7 +176,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy } else { - return Problem::GetVDramTileAccessMaxVectorSize(); + return detail:: + GetDramTileAccessMaxVectorSize(); }; } @@ -615,8 +616,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy auto warp_gemm = [&]() { if constexpr((std::is_same_v || - std::is_same_v)&&std:: - is_same_v) + std::is_same_v) && + std::is_same_v) { constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); @@ -681,8 +682,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy auto warp_gemm = [&]() { if constexpr((std::is_same_v || - std::is_same_v)&&std:: - is_same_v) + std::is_same_v) && + std::is_same_v) { constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});