diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 6f9d53467f..806a471397 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -161,8 +161,43 @@ struct BatchedGemmKernel } CK_TILE_HOST static auto - IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool + IsSupportedArgument(const typename BatchedGemmKernel::KernelArgs& kargs) -> bool { + if(kargs.batch_count < 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met: batch_count must be at least 1 !"); + } + return false; + } + if(kargs.batch_stride_A < 0 || kargs.batch_stride_A < kargs.M * kargs.K) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Conditions not met: batch_stride_A must be non-negative and at least K * M!"); + } + return false; + } + if(kargs.batch_stride_B < 0 || kargs.batch_stride_B < kargs.K * kargs.N) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Conditions not met: batch_stride_B must be non-negative and at least K * N!"); + } + return false; + } + if(kargs.batch_stride_E < 0 || kargs.batch_stride_E < kargs.M * kargs.N) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Conditions not met: batch_stride_E must be non-negative and at least M * N!"); + } + return false; + } return UniversalGemmKernel::IsSupportedArgument(kargs); }