[CK-tile] unhardcode the number of LDS banks from universal gemm policy (#3130)

Fixes LDS bank conflicts on gfx950 for universal gemm v3 pipeline

Replaces hardcoded LDS layer calculations with dynamic computation using the new architecture helpers

Adds architecture-specific helper function get_n_lds_banks()

Changes function attributes from CK_TILE_HOST_DEVICE to CK_TILE_DEVICE in universal gemm policy
This commit is contained in:
Max Podkorytov
2025-10-31 11:58:11 -07:00
committed by GitHub
parent 4ebc48a3cd
commit 04efd282cf
5 changed files with 65 additions and 18 deletions

View File

@@ -442,7 +442,7 @@ struct BlockFmhaV3PipelineDefaultPolicy
constexpr index_t SingleVSize = [&]() {
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t Banks = get_n_lds_banks();
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackK<Problem>();
static_assert(PixelsPerRow % kKPack == 0);

View File

@@ -140,7 +140,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers<Problem>();
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t Banks = get_n_lds_banks();
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(PixelsPerRow % kKPack == 0);

View File

@@ -465,7 +465,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t SingleVSize = [&]() {
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t Banks = get_n_lds_banks();
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackK<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
@@ -620,7 +620,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t Banks = get_n_lds_banks();
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(PixelsPerRow % kKPack == 0);

View File

@@ -71,7 +71,7 @@ struct UniversalGemmBasePolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
@@ -94,7 +94,7 @@ struct UniversalGemmBasePolicy
constexpr auto DataTypeSize = sizeof(ADataType);
constexpr auto MLdsLayer =
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
@@ -141,7 +141,7 @@ struct UniversalGemmBasePolicy
* @return B tensor LDS block descriptor.
*/
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
@@ -166,7 +166,7 @@ struct UniversalGemmBasePolicy
constexpr auto BK0 = number<KPerBlock / KPack>{};
constexpr auto DataTypeSize = sizeof(BDataType);
constexpr auto NLdsLayer =
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(
@@ -658,25 +658,27 @@ struct UniversalGemmBasePolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
CK_TILE_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr auto a_lds_desc = MakeALdsBlockDescriptor<Problem>();
constexpr index_t smem_size_a = integer_least_multiple(
sizeof(typename Problem::ADataType) * a_lds_desc.get_element_space_size(), 16);
constexpr index_t smem_size_a =
integer_least_multiple(sizeof(typename Problem::ADataType) *
Problem::BlockGemmShape::kM * Problem::BlockGemmShape::kK,
16);
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr auto b_lds_desc = MakeBLdsBlockDescriptor<Problem>();
constexpr index_t smem_size_b = integer_least_multiple(
sizeof(typename Problem::BDataType) * b_lds_desc.get_element_space_size(), 16);
constexpr index_t smem_size_b =
integer_least_multiple(sizeof(typename Problem::BDataType) *
Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK,
16);
return smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();