Merge commit 'f7650ee82b306a05d9c3c44d3feefdd570a4bd58' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-05 09:13:29 +00:00
parent eeadb34e8f
commit 86c35117b5

View File

@@ -545,7 +545,7 @@ struct UniversalGemmBasePolicy
}
template <typename Problem, bool IsWave32Host = false>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA()
{
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
@@ -555,6 +555,11 @@ struct UniversalGemmBasePolicy
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
if constexpr(Problem::FixedVectorSize)
{
return Problem::VectorSizeA;
}
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem,
@@ -574,7 +579,7 @@ struct UniversalGemmBasePolicy
}
template <typename Problem, bool IsWave32Host = false>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB()
{
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
@@ -584,6 +589,11 @@ struct UniversalGemmBasePolicy
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
if constexpr(Problem::FixedVectorSize)
{
return Problem::VectorSizeB;
}
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem,