mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Merge commit 'de6a9590abe907283e189abba1b487f8e5562d1b' into develop
This commit is contained in:
10
include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
Normal file → Executable file
10
include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
Normal file → Executable file
@@ -227,9 +227,15 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackA()
|
||||
{
|
||||
return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
|
||||
using A = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
|
||||
|
||||
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A));
|
||||
|
||||
return (KPack < VecElems) ? KPack : VecElems;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
20
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
Normal file → Executable file
20
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
Normal file → Executable file
@@ -97,15 +97,27 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackA()
|
||||
{
|
||||
return Problem::VectorLoadSize;
|
||||
using A = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
|
||||
|
||||
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A));
|
||||
|
||||
return (KPack < VecElems) ? KPack : VecElems;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackB()
|
||||
{
|
||||
return Problem::VectorLoadSize;
|
||||
using B = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
|
||||
|
||||
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(B));
|
||||
|
||||
return (KPack < VecElems) ? KPack : VecElems;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -232,6 +232,10 @@ struct UniversalGemmBasePolicy
|
||||
constexpr auto MLdsLayer =
|
||||
max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr index_t NBanks = get_n_lds_banks();
|
||||
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
|
||||
constexpr index_t RowMul = (NBanks == 64) ? 2 : 1;
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
|
||||
number<MPerBlock / MLdsLayer>{},
|
||||
@@ -243,7 +247,7 @@ struct UniversalGemmBasePolicy
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
|
||||
make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer * RowMul>{},
|
||||
number<KPerBlock / KPack * MLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
@@ -423,6 +427,10 @@ struct UniversalGemmBasePolicy
|
||||
constexpr auto NLdsLayer =
|
||||
max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr index_t NBanks = get_n_lds_banks();
|
||||
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
|
||||
constexpr index_t RowMul = (NBanks == 64) ? 2 : 1;
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(BK0 * number<NLdsLayer>{},
|
||||
number<NPerBlock / NLdsLayer>{},
|
||||
@@ -433,9 +441,10 @@ struct UniversalGemmBasePolicy
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
|
||||
BK0 * number<NLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer * RowMul>{},
|
||||
BK0 * number<NLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
@@ -768,19 +777,27 @@ struct UniversalGemmBasePolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackA()
|
||||
{
|
||||
using A = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
|
||||
constexpr index_t KPack = BlockGemm::Traits::KPack;
|
||||
return KPack;
|
||||
|
||||
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A));
|
||||
|
||||
return (KPack < VecElems) ? KPack : VecElems;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackB()
|
||||
{
|
||||
using B = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
|
||||
constexpr index_t KPack = BlockGemm::Traits::KPack;
|
||||
return KPack;
|
||||
|
||||
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(B));
|
||||
|
||||
return (KPack < VecElems) ? KPack : VecElems;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
Reference in New Issue
Block a user