diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index a8a7288a3d..f392d2e22e 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -120,7 +120,7 @@ int main(int argc, char* argv[]) #if CK_TILE_USE_WMMA return !run_gemm_example(arg_parser); #else - return !run_gemm_example(arg_parser); + return !run_gemm_example(arg_parser); #endif } catch(const std::runtime_error& e) diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp old mode 100644 new mode 100755 index 896f6613a7..31c080c520 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -227,9 +227,15 @@ struct UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackA() { - return Problem::VectorLoadSize / sizeof(typename Problem::ADataType); + using A = remove_cvref_t; + using BlockGemm = remove_cvref_t())>; + + constexpr index_t KPack = static_cast(BlockGemm::Traits::KPack); + constexpr index_t VecElems = static_cast(Problem::VectorLoadSize / sizeof(A)); + + return (KPack < VecElems) ? KPack : VecElems; } template diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp old mode 100644 new mode 100755 index c8f4cfd4ec..712a6b6ac3 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -97,15 +97,27 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackA() { - return Problem::VectorLoadSize; + using A = remove_cvref_t; + using BlockGemm = remove_cvref_t())>; + + constexpr index_t KPack = static_cast(BlockGemm::Traits::KPack); + constexpr index_t VecElems = static_cast(Problem::VectorLoadSize / sizeof(A)); + + return (KPack < VecElems) ? KPack : VecElems; } template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackB() { - return Problem::VectorLoadSize; + using B = remove_cvref_t; + using BlockGemm = remove_cvref_t())>; + + constexpr index_t KPack = static_cast(BlockGemm::Traits::KPack); + constexpr index_t VecElems = static_cast(Problem::VectorLoadSize / sizeof(B)); + + return (KPack < VecElems) ? KPack : VecElems; } template diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index fd95958995..6cca15c1d8 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -232,6 +232,10 @@ struct UniversalGemmBasePolicy constexpr auto MLdsLayer = max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize); + constexpr index_t NBanks = get_n_lds_banks(); + static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count"); + constexpr index_t RowMul = (NBanks == 64) ? 2 : 1; + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, @@ -243,7 +247,7 @@ struct UniversalGemmBasePolicy constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( a_lds_block_desc_0, make_tuple( - make_xor_transform(make_tuple(number{}, + make_xor_transform(make_tuple(number{}, number{})), make_pass_through_transform(number{})), make_tuple(sequence<1, 0>{}, sequence<2>{}), @@ -423,6 +427,10 @@ struct UniversalGemmBasePolicy constexpr auto NLdsLayer = max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize); + constexpr index_t NBanks = get_n_lds_banks(); + static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count"); + constexpr index_t RowMul = (NBanks == 64) ? 2 : 1; + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(BK0 * number{}, number{}, @@ -433,9 +441,10 @@ struct UniversalGemmBasePolicy constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - BK0 * number{})), - make_pass_through_transform(number{})), + make_tuple( + make_xor_transform(make_tuple(number{}, + BK0 * number{})), + make_pass_through_transform(number{})), make_tuple(sequence<1, 0>{}, sequence<2>{}), make_tuple(sequence<1, 0>{}, sequence<2>{})); @@ -768,19 +777,27 @@ struct UniversalGemmBasePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackA() { + using A = remove_cvref_t; using BlockGemm = remove_cvref_t())>; - constexpr index_t KPack = BlockGemm::Traits::KPack; - return KPack; + + constexpr index_t KPack = static_cast(BlockGemm::Traits::KPack); + constexpr index_t VecElems = static_cast(Problem::VectorLoadSize / sizeof(A)); + + return (KPack < VecElems) ? KPack : VecElems; } template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackB() { + using B = remove_cvref_t; using BlockGemm = remove_cvref_t())>; - constexpr index_t KPack = BlockGemm::Traits::KPack; - return KPack; + + constexpr index_t KPack = static_cast(BlockGemm::Traits::KPack); + constexpr index_t VecElems = static_cast(Problem::VectorLoadSize / sizeof(B)); + + return (KPack < VecElems) ? KPack : VecElems; } template