mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
fix enforcing fixedvectorsizes for ck tile conv (#3344)
This commit is contained in:
@@ -545,7 +545,7 @@ struct UniversalGemmBasePolicy
|
||||
}
|
||||
|
||||
template <typename Problem, bool IsWave32Host = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
@@ -555,6 +555,11 @@ struct UniversalGemmBasePolicy
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
|
||||
if constexpr(Problem::FixedVectorSize)
|
||||
{
|
||||
return Problem::VectorSizeA;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem,
|
||||
@@ -574,7 +579,7 @@ struct UniversalGemmBasePolicy
|
||||
}
|
||||
|
||||
template <typename Problem, bool IsWave32Host = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
@@ -584,6 +589,11 @@ struct UniversalGemmBasePolicy
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
|
||||
if constexpr(Problem::FixedVectorSize)
|
||||
{
|
||||
return Problem::VectorSizeB;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem,
|
||||
|
||||
Reference in New Issue
Block a user