[CK_Tile] Fix gemm kernel for 4,64,16 and 64,4,16 warp tile sizes (#2262)

* debugging issue

* debugging issue

* debugging

* debugging

* reverting debugging code

* clang formatted

* updating default_config.json

* fix ci failure

* clang formatted
This commit is contained in:
Khushbu Agarwal
2025-06-03 20:16:10 -07:00
committed by GitHub
parent 1037b21cfe
commit 59a85cb4bc
6 changed files with 46 additions and 17 deletions

View File

@@ -92,7 +92,20 @@ struct CShuffleEpilogue
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
constexpr index_t MaxVectorStoreSize = 16;
return MaxVectorStoreSize / sizeof(ODataType);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(kNPerIteration),
static_cast<int>(MaxVectorStoreSize / sizeof(ODataType)));
}
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(kMPerIteration),
static_cast<int>(MaxVectorStoreSize / sizeof(ODataType)));
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
template <typename Problem>