diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 6990fc1496..8620e7337c 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -319,22 +319,67 @@ struct gfx9_t struct gfx950_t { }; +struct gfx103_t +{ +}; struct gfx11_t { }; struct gfx12_t { }; +struct gfx_invalid_t +{ +}; CK_TILE_DEVICE static constexpr auto get_device_arch() { +// FIXME(0): on all devices except gfx11 it returns gfx12_t +// FIXME(1): during the host compilation pass it returns gfx12_t #if defined(__gfx11__) return gfx11_t{}; -#else // if defined(__gfx12__) +#else return gfx12_t{}; #endif } +CK_TILE_DEVICE static constexpr auto get_n_words_per_128b() { return 4; } + +namespace detail { +CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx9_t) { return 32; } + +CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx103_t) { return 32; } + +CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx11_t) { return 32; } + +CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx12_t) { return 32; } + +CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx950_t) { return 64; } + +CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx_invalid_t) { return 0; } + +CK_TILE_DEVICE static constexpr auto arch_tag_dispatch() +{ +#if defined(__gfx103__) + return gfx103_t{}; +#elif defined(__gfx11__) + return gfx11_t{}; +#elif defined(__gfx12__) + return gfx12_t{}; +#elif defined(__gfx950__) + return gfx950_t{}; +#elif defined(__gfx9__) + return gfx9_t{}; +#else + return gfx_invalid_t{}; +#endif +} +} // namespace detail +CK_TILE_DEVICE static constexpr auto get_n_lds_banks() +{ + return detail::get_n_lds_banks(detail::arch_tag_dispatch()); +} + enum LLVMSchedGroupMask : int32_t { NONE = 0, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp index e440280d7e..ac04f54adf 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp @@ -442,7 +442,7 @@ struct BlockFmhaV3PipelineDefaultPolicy constexpr index_t SingleVSize = [&]() { using VDataType = remove_cvref_t; - constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t Banks = get_n_lds_banks(); constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); constexpr index_t kKPack = GetSmemKPackK(); static_assert(PixelsPerRow % kKPack == 0); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index 014467fe8a..0b9986e083 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -140,7 +140,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers(); - constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t Banks = get_n_lds_banks(); constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); constexpr index_t kKPack = GetSmemKPackV(); static_assert(PixelsPerRow % kKPack == 0); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 7c794a3646..06c6dce6b0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -465,7 +465,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t Banks = get_n_lds_banks(); constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); constexpr index_t kKPack = GetSmemKPackK(); static_assert(PixelsPerRow % kKPack == 0); @@ -620,7 +620,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t Banks = get_n_lds_banks(); constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); constexpr index_t kKPack = GetSmemKPackV(); static_assert(PixelsPerRow % kKPack == 0); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 4030783ecc..ecff6fe497 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -71,7 +71,7 @@ struct UniversalGemmBasePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor() { using ADataType = remove_cvref_t; @@ -94,7 +94,7 @@ struct UniversalGemmBasePolicy constexpr auto DataTypeSize = sizeof(ADataType); constexpr auto MLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); + max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize); constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, @@ -141,7 +141,7 @@ struct UniversalGemmBasePolicy * @return B tensor LDS block descriptor. */ template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { using BDataType = remove_cvref_t; @@ -166,7 +166,7 @@ struct UniversalGemmBasePolicy constexpr auto BK0 = number{}; constexpr auto DataTypeSize = sizeof(BDataType); constexpr auto NLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); + max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize); constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple( @@ -658,25 +658,27 @@ struct UniversalGemmBasePolicy } template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() + CK_TILE_DEVICE static constexpr index_t GetSmemSizeA() { - constexpr auto a_lds_desc = MakeALdsBlockDescriptor(); - constexpr index_t smem_size_a = integer_least_multiple( - sizeof(typename Problem::ADataType) * a_lds_desc.get_element_space_size(), 16); + constexpr index_t smem_size_a = + integer_least_multiple(sizeof(typename Problem::ADataType) * + Problem::BlockGemmShape::kM * Problem::BlockGemmShape::kK, + 16); return smem_size_a; } template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() + CK_TILE_DEVICE static constexpr index_t GetSmemSizeB() { - constexpr auto b_lds_desc = MakeBLdsBlockDescriptor(); - constexpr index_t smem_size_b = integer_least_multiple( - sizeof(typename Problem::BDataType) * b_lds_desc.get_element_space_size(), 16); + constexpr index_t smem_size_b = + integer_least_multiple(sizeof(typename Problem::BDataType) * + Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, + 16); return smem_size_b; } template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + CK_TILE_DEVICE static constexpr index_t GetSmemSize() { constexpr index_t smem_size_a = GetSmemSizeA(); constexpr index_t smem_size_b = GetSmemSizeB();