[CK-tile] Un-hardcode the number of LDS banks in universal gemm policy (#3130)

Fixes LDS bank conflicts on gfx950 for the universal gemm v3 pipeline.

Replaces the hardcoded LDS layer calculations with dynamic computation using the new architecture helpers.

Adds the architecture-specific helper function get_n_lds_banks().

Changes function attributes from CK_TILE_HOST_DEVICE to CK_TILE_DEVICE in the universal gemm policy.
This commit is contained in:
Max Podkorytov
2025-10-31 11:58:11 -07:00
committed by GitHub
parent 4ebc48a3cd
commit 04efd282cf
5 changed files with 65 additions and 18 deletions

View File

@@ -442,7 +442,7 @@ struct BlockFmhaV3PipelineDefaultPolicy
constexpr index_t SingleVSize = [&]() {
using VDataType = remove_cvref_t<typename Problem::VDataType>;
-constexpr index_t Banks = 32; // TODO: need change based on arch
+constexpr index_t Banks = get_n_lds_banks();
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackK<Problem>();
static_assert(PixelsPerRow % kKPack == 0);

View File

@@ -140,7 +140,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers<Problem>();
-constexpr index_t Banks = 32; // TODO: need change based on arch
+constexpr index_t Banks = get_n_lds_banks();
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(PixelsPerRow % kKPack == 0);

View File

@@ -465,7 +465,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t SingleVSize = [&]() {
using VDataType = remove_cvref_t<typename Problem::VDataType>;
-constexpr index_t Banks = 32; // TODO: need change based on arch
+constexpr index_t Banks = get_n_lds_banks();
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackK<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
@@ -620,7 +620,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
-constexpr index_t Banks = 32; // TODO: need change based on arch
+constexpr index_t Banks = get_n_lds_banks();
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(PixelsPerRow % kKPack == 0);