[CK_TILE] FMHA Support hdim_v to as a Multiple of 32 (#2114)

* 160+192

* Add splitkv d160

* cleanup

* fix

* Add change log

* Fix CHANGELOG

* Use static_cast

* Update ignored instance

---------

Co-authored-by: asleepzzz <hanwen.chang@amd.com>

[ROCm/composable_kernel commit: b8212864cf]
This commit is contained in:
Yi DING
2025-06-24 01:33:31 +08:00
committed by GitHub
parent 470608bdcb
commit cba904aeff
7 changed files with 89 additions and 68 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -129,7 +129,10 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
// set output vectors
static_for<0, num_vec_out, 1>{}([&](auto i) {
constexpr auto idx_y_out_tmp = generate_array(
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
[&](auto ii) {
return ii == y_dim_vec_in ? static_cast<index_t>(idx_y_start[ii]) + i
: static_cast<index_t>(idx_y_start[ii]);
},
number<NDimY>{});
constexpr auto idx_y_out =

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -787,12 +787,29 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t K3 = total_pixels / N1;
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
if constexpr(get_warp_size() % (K2 * N0) == 0)
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
{
constexpr index_t kNPack = 32;
static_assert(kNPerBlock % kNPack == 0);
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t N2 = 2;
constexpr index_t N1_m = kNPack / N2;
constexpr index_t N0_m = kNPerBlock / kNPack;
constexpr index_t K1 = get_warp_size() / N1_m;
constexpr index_t K2_m = kKPerBlock / K1;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
tuple<sequence<0>, sequence<1, 1>>,
sequence<1, 2, 1>, // N0 K2 N2
sequence<0, 2, 2>>{});
}
else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
@@ -860,12 +877,28 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
if constexpr(get_warp_size() % (K2 * N0) == 0)
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
{
constexpr index_t kNPack = 32;
static_assert(kNPerBlock % kNPack == 0);
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t N2 = 2;
constexpr index_t N1_m = kNPack / N2;
constexpr index_t N0_m = kNPerBlock / kNPack;
constexpr index_t K1 = get_warp_size() / N1_m;
constexpr index_t K2_m = kKPerBlock / K1;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
tuple<sequence<0>, sequence<1, 1>>,
sequence<1, 1, 2>, // N0 K2 <-> N2
sequence<0, 2, 2>>{});
}
else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();