Merge some updates for ck_tile headers (#3342)

* fix some issues from internal branch

* update cshuffle_epilogue

* update cshuffle_epilogue

* update cshuffle

* update warp_gemm
This commit is contained in:
joyeamd
2026-01-06 15:39:00 +08:00
committed by GitHub
parent 2b563ad048
commit b78563b3d3
14 changed files with 205 additions and 119 deletions

View File

@@ -423,7 +423,7 @@ struct UniversalGemmKernel
const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
: GemmPipeline::template GetVectorSizeA<false>();
bool AsTesnorIsValid = {true};
bool AsTensorIsValid = {true};
static_for<0, NumATensor, 1>{}([&](auto index) {
using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
@@ -437,15 +437,27 @@ struct UniversalGemmKernel
"Can't support K that is not a multiple of k_batch * KPerBlock "
"without padding!");
}
AsTesnorIsValid = false;
AsTensorIsValid = false;
}
if(kargs.K % vectorSizeA != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
const auto remainder = kargs.K % vectorSizeA;
constexpr ck_tile::index_t APackedSize =
ck_tile::numeric_traits<ADataType>::PackedSize;
const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize;
// oob can support to dword level
if(remainder_in_bytes % 4 == 0)
{
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
AsTensorIsValid = true;
}
else
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
}
AsTensorIsValid = false;
}
AsTesnorIsValid = false;
}
}
else
@@ -457,20 +469,33 @@ struct UniversalGemmKernel
CK_TILE_ERROR(
"Can't support M that is not a multiple of MPerBlock without padding!");
}
AsTesnorIsValid = false;
AsTensorIsValid = false;
}
if(kargs.M % vectorSizeA != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
const auto remainder = kargs.M % vectorSizeA;
constexpr ck_tile::index_t APackedSize =
ck_tile::numeric_traits<ADataType>::PackedSize;
const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize;
// oob can support to dword level
if(remainder_in_bytes % 4 == 0)
{
CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
AsTensorIsValid = true;
}
else
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
}
AsTensorIsValid = false;
}
AsTesnorIsValid = false;
}
}
});
bool BsTesnorIsValid = {true};
bool BsTensorIsValid = {true};
const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB<true>()
: GemmPipeline::template GetVectorSizeB<false>();
static_for<0, NumBTensor, 1>{}([&](auto index) {
@@ -484,47 +509,72 @@ struct UniversalGemmKernel
CK_TILE_ERROR(
"Can't support N that is not a multiple of NPerBlock without padding!");
}
BsTesnorIsValid = false;
BsTensorIsValid = false;
}
if(kargs.N % vectorSizeB != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
const auto remainder = kargs.N % vectorSizeB;
constexpr ck_tile::index_t BPackedSize =
ck_tile::numeric_traits<BDataType>::PackedSize;
const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize;
// oob can support to dword level
if(remainder_in_bytes % 4 == 0)
{
CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
BsTensorIsValid = true;
}
else
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
}
BsTensorIsValid = false;
}
BsTesnorIsValid = false;
}
}
else
{
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
GemmPipeline::kPadK == false)
else
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
GemmPipeline::kPadK == false)
{
CK_TILE_ERROR(
"Can't support K that is not a multiple of k_batch * KPerBlock "
"without padding!");
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Can't support K that is not a multiple of k_batch * KPerBlock "
"without padding!");
}
BsTensorIsValid = false;
}
BsTesnorIsValid = false;
}
if(kargs.K % vectorSizeB != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
if(kargs.K % vectorSizeB != 0)
{
CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
const auto remainder = kargs.K % vectorSizeB;
constexpr ck_tile::index_t BPackedSize =
ck_tile::numeric_traits<BDataType>::PackedSize;
const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize;
// oob can support to dword level
if(remainder_in_bytes % 4 == 0)
{
BsTensorIsValid = true;
}
else
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"K is not a multiple of vector load size for B tensor!");
}
BsTensorIsValid = false;
}
}
BsTesnorIsValid = false;
}
}
});
bool DTesnorIsValid = {true};
bool DTensorIsValid = {true};
static_for<0, NumDTensor, 1>{}([&](auto index) {
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
if(std::is_same_v<DiLayout, CLayout> == false)
{
DTesnorIsValid = false;
DTensorIsValid = false;
}
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
@@ -535,7 +585,7 @@ struct UniversalGemmKernel
CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
"NPerBlock without padding!");
}
DTesnorIsValid = false;
DTensorIsValid = false;
}
if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
{
@@ -543,7 +593,7 @@ struct UniversalGemmKernel
{
CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
}
DTesnorIsValid = false;
DTensorIsValid = false;
}
}
else
@@ -555,7 +605,7 @@ struct UniversalGemmKernel
CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
"MPerBlock without padding!");
}
DTesnorIsValid = false;
DTensorIsValid = false;
}
if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
{
@@ -563,7 +613,7 @@ struct UniversalGemmKernel
{
CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
}
DTesnorIsValid = false;
DTensorIsValid = false;
}
}
});
@@ -608,7 +658,7 @@ struct UniversalGemmKernel
return false;
}
}
return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid;
return AsTensorIsValid && BsTensorIsValid && DTensorIsValid;
}
CK_TILE_DEVICE static auto