Merge commit 'de6a9590abe907283e189abba1b487f8e5562d1b' into develop

This commit is contained in:
assistant-librarian[bot]
2025-11-24 21:29:18 +00:00
parent 5297edb40c
commit 4aaa8c92bb
4 changed files with 52 additions and 17 deletions

View File

@@ -120,7 +120,7 @@ int main(int argc, char* argv[])
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigComputeV3_WMMA>(arg_parser);
#else
return !run_gemm_example<GemmConfigComputeV3>(arg_parser);
return !run_gemm_example<GemmConfigComputeV3_2>(arg_parser);
#endif
}
catch(const std::runtime_error& e)

View File

@@ -227,9 +227,15 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackA()
{
return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
using A = remove_cvref_t<typename Problem::ADataType>;
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A));
return (KPack < VecElems) ? KPack : VecElems;
}
template <typename Problem>

View File

@@ -97,15 +97,27 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackA()
{
return Problem::VectorLoadSize;
using A = remove_cvref_t<typename Problem::ADataType>;
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A));
return (KPack < VecElems) ? KPack : VecElems;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackB()
{
return Problem::VectorLoadSize;
using B = remove_cvref_t<typename Problem::BDataType>;
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(B));
return (KPack < VecElems) ? KPack : VecElems;
}
template <typename Problem>

View File

@@ -232,6 +232,10 @@ struct UniversalGemmBasePolicy
constexpr auto MLdsLayer =
max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
constexpr index_t NBanks = get_n_lds_banks();
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
constexpr index_t RowMul = (NBanks == 64) ? 2 : 1;
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
number<MPerBlock / MLdsLayer>{},
@@ -243,7 +247,7 @@ struct UniversalGemmBasePolicy
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(
make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer * RowMul>{},
number<KPerBlock / KPack * MLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
@@ -423,6 +427,10 @@ struct UniversalGemmBasePolicy
constexpr auto NLdsLayer =
max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
constexpr index_t NBanks = get_n_lds_banks();
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
constexpr index_t RowMul = (NBanks == 64) ? 2 : 1;
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(BK0 * number<NLdsLayer>{},
number<NPerBlock / NLdsLayer>{},
@@ -433,9 +441,10 @@ struct UniversalGemmBasePolicy
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
BK0 * number<NLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(
make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer * RowMul>{},
BK0 * number<NLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
@@ -768,19 +777,27 @@ struct UniversalGemmBasePolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackA()
{
using A = remove_cvref_t<typename Problem::ADataType>;
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
constexpr index_t KPack = BlockGemm::Traits::KPack;
return KPack;
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A));
return (KPack < VecElems) ? KPack : VecElems;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackB()
{
using B = remove_cvref_t<typename Problem::BDataType>;
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
constexpr index_t KPack = BlockGemm::Traits::KPack;
return KPack;
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(B));
return (KPack < VecElems) ? KPack : VecElems;
}
template <typename Problem>