From f1c8acbd71e835e4bacd6c421f869e23320a1606 Mon Sep 17 00:00:00 2001 From: aledudek Date: Mon, 13 Oct 2025 13:55:23 +0200 Subject: [PATCH] [CK_TILE] Batched Gemm Kernel IsSupported function checks (#2860) * Add valid check batched gemm part1 * [CK_TILE] Add batched gemm kernel IsSupported func checks * revert broken pre-commit hook changes * revert broken pre-commit hook changes v2 * Clarify error messages [ROCm/composable_kernel commit: 3021604213750fc5acb02dad50e60ea8b0176b91] --- .../ops/gemm/kernel/batched_gemm_kernel.hpp | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 6f9d53467f..806a471397 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -161,8 +161,43 @@ struct BatchedGemmKernel } CK_TILE_HOST static auto - IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool + IsSupportedArgument(const typename BatchedGemmKernel::KernelArgs& kargs) -> bool { + if(kargs.batch_count < 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met: batch_count must be at least 1 !"); + } + return false; + } + if(kargs.batch_stride_A < 0 || kargs.batch_stride_A < kargs.M * kargs.K) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Conditions not met: batch_stride_A must be non-negative and at least K * M!"); + } + return false; + } + if(kargs.batch_stride_B < 0 || kargs.batch_stride_B < kargs.K * kargs.N) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Conditions not met: batch_stride_B must be non-negative and at least K * N!"); + } + return false; + } + if(kargs.batch_stride_E < 0 || kargs.batch_stride_E < kargs.M * kargs.N) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Conditions not met: batch_stride_E must be non-negative and at least M * N!"); + } + return false; + } return UniversalGemmKernel::IsSupportedArgument(kargs); }