From 180b726f97c0295d77e5db33ff7661f1e773be7c Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Fri, 7 Jun 2024 17:49:44 +0000 Subject: [PATCH] Fix wrong kBlockSize used in policy --- .../pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index d4adf9792f..775bf82b6c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -991,7 +991,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - constexpr index_t kBlockSize = 256; + constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::kMaxSplits; @@ -1031,7 +1031,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccTDramTileDistribution() { - constexpr index_t kBlockSize = 256; + constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size()); constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;