mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK-tile] unhardcode the number of LDS banks from universal gemm policy (#3130)
Fixes LDS bank conflicts on gfx950 for the universal gemm v3 pipeline. Replaces hardcoded LDS layer calculations with dynamic computation using the new architecture helpers. Adds the architecture-specific helper function get_n_lds_banks(). Changes function attributes from CK_TILE_HOST_DEVICE to CK_TILE_DEVICE in the universal gemm policy.
This commit is contained in:
@@ -319,22 +319,67 @@ struct gfx9_t
|
||||
// Empty tag types naming GPU architecture families, plus gfx_invalid_t as the
// "no recognized architecture" tag. They carry no data; the
// detail::get_n_lds_banks overloads take them as parameters so the matching
// overload is chosen by tag type.
struct gfx950_t
|
||||
{
|
||||
};
|
||||
struct gfx103_t
|
||||
{
|
||||
};
|
||||
struct gfx11_t
|
||||
{
|
||||
};
|
||||
struct gfx12_t
|
||||
{
|
||||
};
|
||||
struct gfx_invalid_t
|
||||
{
|
||||
};
|
||||
|
||||
// Returns an architecture tag object: gfx11_t{} when __gfx11__ is defined,
// gfx12_t{} in every other case (see the FIXMEs below for the known
// consequences of that fallback).
CK_TILE_DEVICE static constexpr auto get_device_arch()
|
||||
{
|
||||
// FIXME(0): on all devices except gfx11 it returns gfx12_t
|
||||
// FIXME(1): during the host compilation pass it returns gfx12_t
|
||||
#if defined(__gfx11__)
|
||||
return gfx11_t{};
|
||||
#else // if defined(__gfx12__)
|
||||
#else
|
||||
return gfx12_t{};
|
||||
#endif
|
||||
}
|
||||
|
||||
// Returns the constant 4: the number of words in 128 bits per this helper's
// name (presumably 32-bit words, 128 / 32 -- TODO confirm the word size).
CK_TILE_DEVICE static constexpr auto get_n_words_per_128b() { return 4; }
|
||||
|
||||
// Implementation helpers for get_n_lds_banks(): one overload per architecture
// tag giving the LDS bank count, plus a preprocessor-driven function that
// picks the tag for the current compilation pass.
namespace detail {
|
||||
// Bank count per architecture tag: 32 for gfx9, gfx103, gfx11 and gfx12.
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx9_t) { return 32; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx103_t) { return 32; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx11_t) { return 32; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx12_t) { return 32; }
|
||||
|
||||
// gfx950 is the only tag here that reports 64 banks.
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx950_t) { return 64; }
|
||||
|
||||
// Fallback for an unrecognized architecture: reports 0 banks.
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx_invalid_t) { return 0; }
|
||||
|
||||
// Maps the compiler-defined __gfx*__ macros to an architecture tag object.
// The first matching branch wins, so __gfx950__ is tested before __gfx9__
// (presumably because __gfx9__ is also defined on gfx950 -- TODO confirm).
// Returns gfx_invalid_t{} when none of the macros is defined.
CK_TILE_DEVICE static constexpr auto arch_tag_dispatch()
|
||||
{
|
||||
#if defined(__gfx103__)
|
||||
return gfx103_t{};
|
||||
#elif defined(__gfx11__)
|
||||
return gfx11_t{};
|
||||
#elif defined(__gfx12__)
|
||||
return gfx12_t{};
|
||||
#elif defined(__gfx950__)
|
||||
return gfx950_t{};
|
||||
#elif defined(__gfx9__)
|
||||
return gfx9_t{};
|
||||
#else
|
||||
return gfx_invalid_t{};
|
||||
#endif
|
||||
}
|
||||
} // namespace detail
|
||||
// Returns the LDS bank count for the architecture selected by
// detail::arch_tag_dispatch(): 64 for gfx950, 32 for gfx9/gfx103/gfx11/gfx12,
// and 0 when no recognized __gfx*__ macro is defined.
// NOTE(review): callers that divide by or size buffers from this value should
// account for the 0 fallback (e.g. a pass where no __gfx*__ macro is set).
CK_TILE_DEVICE static constexpr auto get_n_lds_banks()
|
||||
{
|
||||
return detail::get_n_lds_banks(detail::arch_tag_dispatch());
|
||||
}
|
||||
|
||||
enum LLVMSchedGroupMask : int32_t
|
||||
{
|
||||
NONE = 0,
|
||||
|
||||
@@ -442,7 +442,7 @@ struct BlockFmhaV3PipelineDefaultPolicy
|
||||
|
||||
constexpr index_t SingleVSize = [&]() {
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t Banks = 32; // TODO: need change based on arch
|
||||
constexpr index_t Banks = get_n_lds_banks();
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
static_assert(PixelsPerRow % kKPack == 0);
|
||||
|
||||
@@ -140,7 +140,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
|
||||
constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers<Problem>();
|
||||
|
||||
constexpr index_t Banks = 32; // TODO: need change based on arch
|
||||
constexpr index_t Banks = get_n_lds_banks();
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(PixelsPerRow % kKPack == 0);
|
||||
|
||||
@@ -465,7 +465,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
|
||||
constexpr index_t SingleVSize = [&]() {
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t Banks = 32; // TODO: need change based on arch
|
||||
constexpr index_t Banks = get_n_lds_banks();
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
static_assert(PixelsPerRow % kKPack == 0);
|
||||
@@ -620,7 +620,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
|
||||
{
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t Banks = 32; // TODO: need change based on arch
|
||||
constexpr index_t Banks = get_n_lds_banks();
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(PixelsPerRow % kKPack == 0);
|
||||
|
||||
@@ -71,7 +71,7 @@ struct UniversalGemmBasePolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
@@ -94,7 +94,7 @@ struct UniversalGemmBasePolicy
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
constexpr auto MLdsLayer =
|
||||
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
|
||||
max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
|
||||
@@ -141,7 +141,7 @@ struct UniversalGemmBasePolicy
|
||||
* @return B tensor LDS block descriptor.
|
||||
*/
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
@@ -166,7 +166,7 @@ struct UniversalGemmBasePolicy
|
||||
constexpr auto BK0 = number<KPerBlock / KPack>{};
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
|
||||
max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
@@ -658,25 +658,27 @@ struct UniversalGemmBasePolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
constexpr auto a_lds_desc = MakeALdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_a = integer_least_multiple(
|
||||
sizeof(typename Problem::ADataType) * a_lds_desc.get_element_space_size(), 16);
|
||||
constexpr index_t smem_size_a =
|
||||
integer_least_multiple(sizeof(typename Problem::ADataType) *
|
||||
Problem::BlockGemmShape::kM * Problem::BlockGemmShape::kK,
|
||||
16);
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
{
|
||||
constexpr auto b_lds_desc = MakeBLdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_b = integer_least_multiple(
|
||||
sizeof(typename Problem::BDataType) * b_lds_desc.get_element_space_size(), 16);
|
||||
constexpr index_t smem_size_b =
|
||||
integer_least_multiple(sizeof(typename Problem::BDataType) *
|
||||
Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK,
|
||||
16);
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
|
||||
|
||||
Reference in New Issue
Block a user