diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 2c4385027d..52d8146432 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -39,7 +39,7 @@ CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize() return 1; } else - static_assert(false, "The data type is not supported!"); + return 1; }; template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() { + using VDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + // special consideration when shuffling is required before storing V to LDS if constexpr(!Problem::kUseTrLoad) { - using VDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - constexpr index_t kMaxVecLoad = Problem::GetVDramTileAccessMaxVectorSize(); + constexpr index_t kMaxVecLoad = detail:: + GetDramTileAccessMaxVectorSize(); constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); // try to avoid writing sub-dword to LDS due to poor performance @@ -176,7 +176,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy } else { - return Problem::GetVDramTileAccessMaxVectorSize(); + return detail:: + GetDramTileAccessMaxVectorSize(); }; } @@ -615,8 +616,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy auto warp_gemm = [&]() { if constexpr((std::is_same_v || - std::is_same_v)&&std:: - is_same_v) + std::is_same_v) && + std::is_same_v) { constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); @@ -681,8 +682,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy auto warp_gemm = [&]() { if constexpr((std::is_same_v || - std::is_same_v)&&std:: - is_same_v) + std::is_same_v) && + std::is_same_v) { constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});