[CK_TILE] Batched Gemm Kernel IsSupported function checks (#2860)

* Add valid check batched gemm part1

* [CK_TILE] Add batched gemm kernel IsSupported func checks

* revert broken pre-commit hook changes

* revert broken pre-commit hook changes v2

* Clarify error messages
This commit is contained in:
aledudek
2025-10-13 13:55:23 +02:00
committed by GitHub
parent 46c10c316d
commit 3021604213

View File

@@ -161,8 +161,43 @@ struct BatchedGemmKernel
}
CK_TILE_HOST static auto
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
IsSupportedArgument(const typename BatchedGemmKernel::KernelArgs& kargs) -> bool
{
if(kargs.batch_count < 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Conditions not met: batch_count must be at least 1 !");
}
return false;
}
if(kargs.batch_stride_A < 0 || kargs.batch_stride_A < kargs.M * kargs.K)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Conditions not met: batch_stride_A must be non-negative and at least K * M!");
}
return false;
}
if(kargs.batch_stride_B < 0 || kargs.batch_stride_B < kargs.K * kargs.N)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Conditions not met: batch_stride_B must be non-negative and at least K * N!");
}
return false;
}
if(kargs.batch_stride_E < 0 || kargs.batch_stride_E < kargs.M * kargs.N)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Conditions not met: batch_stride_E must be non-negative and at least M * N!");
}
return false;
}
return UniversalGemmKernel::IsSupportedArgument(kargs);
}