mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Batched Gemm Kernel IsSupported function checks (#2860)
* Add valid check batched gemm part1 * [CK_TILE] Add batched gemm kernel IsSupported func checks * revert broken pre-commit hook changes * revert broken pre-commit hook changes v2 * Clarify error messages
This commit is contained in:
@@ -161,8 +161,43 @@ struct BatchedGemmKernel
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto
|
||||
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
|
||||
IsSupportedArgument(const typename BatchedGemmKernel::KernelArgs& kargs) -> bool
|
||||
{
|
||||
if(kargs.batch_count < 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met: batch_count must be at least 1 !");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.batch_stride_A < 0 || kargs.batch_stride_A < kargs.M * kargs.K)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Conditions not met: batch_stride_A must be non-negative and at least K * M!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.batch_stride_B < 0 || kargs.batch_stride_B < kargs.K * kargs.N)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Conditions not met: batch_stride_B must be non-negative and at least K * N!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.batch_stride_E < 0 || kargs.batch_stride_E < kargs.M * kargs.N)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Conditions not met: batch_stride_E must be non-negative and at least M * N!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return UniversalGemmKernel::IsSupportedArgument(kargs);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user