Extract common logics

This commit is contained in:
PoYen, Chen
2024-06-26 18:02:28 +00:00
parent 8fb567c286
commit c40c1daff0

View File

@@ -90,11 +90,11 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVnewDramTileDistribution()
{
using VLayout = remove_cvref_t<typename Problem::VLayout>;
using VLayout = remove_cvref_t<typename Problem::VLayout>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
static_assert(!std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t kNPerBlock = Problem::kTileSizeDv;
@@ -105,7 +105,6 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
static_assert(N0 != 0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
@@ -117,8 +116,6 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
}
else
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kNPerBlock = Problem::kTileSizeDv;
constexpr index_t kKPerBlock = Problem::kTileSizeSk;