fix enforcing fixedvectorsizes for ck tile conv (#3344)

This commit is contained in:
jakpiase
2025-12-05 09:30:22 +01:00
committed by GitHub
parent 13f6d63565
commit f7650ee82b

View File

@@ -545,7 +545,7 @@ struct UniversalGemmBasePolicy
}
template <typename Problem, bool IsWave32Host = false>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA()
{
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
@@ -555,6 +555,11 @@ struct UniversalGemmBasePolicy
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
if constexpr(Problem::FixedVectorSize)
{
return Problem::VectorSizeA;
}
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem,
@@ -574,7 +579,7 @@ struct UniversalGemmBasePolicy
}
template <typename Problem, bool IsWave32Host = false>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB()
{
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
@@ -584,6 +589,11 @@ struct UniversalGemmBasePolicy
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
if constexpr(Problem::FixedVectorSize)
{
return Problem::VectorSizeB;
}
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem,