[CK-tile] unhardcode the number of LDS banks from universal gemm policy (#3130)

Fixes LDS bank conflicts on gfx950 for universal gemm v3 pipeline

Replaces hardcoded LDS layer calculations with dynamic computation using the new architecture helpers

Adds architecture-specific helper function get_n_lds_banks()

Changes function attributes from CK_TILE_HOST_DEVICE to CK_TILE_DEVICE in universal gemm policy
This commit is contained in:
Max Podkorytov
2025-10-31 11:58:11 -07:00
committed by GitHub
parent 4ebc48a3cd
commit 04efd282cf
5 changed files with 65 additions and 18 deletions

View File

@@ -319,22 +319,67 @@ struct gfx9_t
struct gfx950_t
{
};
struct gfx103_t
{
};
struct gfx11_t
{
};
struct gfx12_t
{
};
struct gfx_invalid_t
{
};
CK_TILE_DEVICE static constexpr auto get_device_arch()
{
// FIXME(0): on all devices except gfx11 it returns gfx12_t
// FIXME(1): during the host compilation pass it returns gfx12_t
#if defined(__gfx11__)
return gfx11_t{};
#else // if defined(__gfx12__)
#else
return gfx12_t{};
#endif
}
CK_TILE_DEVICE static constexpr auto get_n_words_per_128b() { return 4; }
namespace detail {
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx9_t) { return 32; }
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx103_t) { return 32; }
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx11_t) { return 32; }
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx12_t) { return 32; }
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx950_t) { return 64; }
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx_invalid_t) { return 0; }
CK_TILE_DEVICE static constexpr auto arch_tag_dispatch()
{
#if defined(__gfx103__)
return gfx103_t{};
#elif defined(__gfx11__)
return gfx11_t{};
#elif defined(__gfx12__)
return gfx12_t{};
#elif defined(__gfx950__)
return gfx950_t{};
#elif defined(__gfx9__)
return gfx9_t{};
#else
return gfx_invalid_t{};
#endif
}
} // namespace detail
CK_TILE_DEVICE static constexpr auto get_n_lds_banks()
{
return detail::get_n_lds_banks(detail::arch_tag_dispatch());
}
enum LLVMSchedGroupMask : int32_t
{
NONE = 0,