mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-28 18:56:59 +00:00
[CK_TILE] [QuantGEMM] Fix SplitK tail handling and other improvements (#7199)

This pull request introduces improved and more robust split-K support for quantized GEMM. The main changes add runtime validation, utility functions for split-K batch calculations, pointer offset handling for split-K in grouped kernels, and enhanced support for various tensor layouts. The changes also improve error handling and provide more flexibility for runtime tail handling in split-K pipelines.

**Split-K Support and Validation Enhancements:**

* Added runtime validation to ensure `k_batch` is a positive integer and that split-K configurations do not produce empty final batches or mismatched pipeline tails, with detailed error messages and logging for misconfiguration.
* Introduced utility functions `get_splitk_batch_k_read` and `get_splitk_last_batch_k` to compute per-batch K read sizes and handle split rounding, ensuring correct and consistent split-K batch partitioning.
* Changed the default value of `k_batch` in `QuantGemmHostArgs` to 1 (no split-K) for safer default behavior.

**Pointer Offsets and Grouped Kernel Handling:**

* Updated `QuantGroupedGemmKernel` to apply split-K per-batch offsets to all input pointers, mirroring the behavior of non-grouped kernels and ensuring correctness for split-K launches.
* Modified AQ tensor view handling to correctly reflect the remaining K-groups from the split-K batch's offset position, improving accuracy for split-K in grouped kernels.

**Pipeline and Layout Flexibility:**

* Added support for runtime selection of split-K tail handling via a new template parameter `RuntimeSplitKTail_`, with new helper methods to dispatch GEMM pipelines accordingly.
* Improved handling for tensor layout cases, including preshuffled B and both row-major and column-major AQ layouts, ensuring correct pointer arithmetic and compatibility checks.
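
To make the batch partitioning and the "empty final batch" check concrete, here is a minimal host-side sketch. It is not the CK implementation: the function names (`sketch_batch_k_read`, `sketch_last_batch_k`, `sketch_validate_split_k`, `sketch_batch_k_offset`), the `k_align` rounding granularity, and the rounding rule itself (ceiling division, then round up to a multiple of `k_align`) are assumptions made for illustration. The actual signatures and rounding of `get_splitk_batch_k_read` and `get_splitk_last_batch_k` may differ.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical helpers for illustration only; these are NOT the CK functions.
// Assumption: every batch except the last reads ceil(K / k_batch) elements,
// rounded up to a multiple of k_align (e.g. a K tile size).
inline std::int64_t sketch_batch_k_read(std::int64_t K, std::int64_t k_batch, std::int64_t k_align)
{
    assert(k_batch > 0 && k_align > 0);
    const std::int64_t per_batch = (K + k_batch - 1) / k_batch; // ceiling division
    return (per_batch + k_align - 1) / k_align * k_align;       // round up to k_align
}

// K elements left for the final batch after the first (k_batch - 1) batches
// have each consumed sketch_batch_k_read(...). The result can be <= 0 when
// rounding makes the earlier batches cover all of K.
inline std::int64_t sketch_last_batch_k(std::int64_t K, std::int64_t k_batch, std::int64_t k_align)
{
    return K - sketch_batch_k_read(K, k_batch, k_align) * (k_batch - 1);
}

// Rejects a non-positive k_batch and any split that leaves the final batch empty.
inline void sketch_validate_split_k(std::int64_t K, std::int64_t k_batch, std::int64_t k_align)
{
    if(k_batch <= 0)
    {
        throw std::invalid_argument("k_batch must be a positive integer, got " +
                                    std::to_string(k_batch));
    }
    const std::int64_t last_k = sketch_last_batch_k(K, k_batch, k_align);
    if(last_k <= 0)
    {
        throw std::invalid_argument("split-K leaves an empty final batch (last batch K = " +
                                    std::to_string(last_k) + ")");
    }
}

// Element offset of a batch's first K element, for a tensor whose K dimension
// has stride k_stride (1 when K is the contiguous dimension).
inline std::int64_t sketch_batch_k_offset(std::int64_t batch_id, std::int64_t k_read, std::int64_t k_stride)
{
    return batch_id * k_read * k_stride;
}
```

Under these assumptions, `K = 1000`, `k_batch = 3`, `k_align = 64` gives a per-batch read of 384 and a final batch of 232 elements, which passes validation. With `k_batch = 5` the per-batch read rounds up to 256, the first four batches already cover 1024 elements, and the final batch would be empty, so the check throws.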