From 016da0d5f032f37b356bf71cb7d006de38d64c80 Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Fri, 19 Jun 2026 16:27:02 +0500 Subject: [PATCH] Support large tensors in quant gemm kernel Implemented only for large M, RCR layout with RowCol quantization. --- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 35 +++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 5f013b5a94..711f65c10e 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -1823,11 +1823,11 @@ struct QuantGemmMultiDKernel } } - CK_TILE_DEVICE void Run_(const KernelArgs& kargs) const + CK_TILE_DEVICE void Run_(KernelArgs kargs) const { const auto blockId = amd_wave_read_first_lane(blockIdx.x); const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); - const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); const SplitKBatchOffset splitk_batch_offset(kargs); @@ -1844,13 +1844,42 @@ struct QuantGemmMultiDKernel static_cast(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset; CDataType* c_ptr = static_cast(kargs.c_ptr); + std::array ds_ptr = kargs.ds_ptr; + + // Large tensor support (when M is large, N and K are relatively small) + constexpr bool offset_ptrs_by_tile_coords = [] { + bool suitable = kQuantType == QuantType::RowColQuant; + suitable = suitable && std::is_same_v; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DiLayout = remove_cvref_t>; + suitable = suitable && std::is_same_v; + }); + suitable = suitable && std::is_same_v; + return suitable; + }(); /* invoked immediately: the flag must hold the lambda's result; an un-called captureless lambda converts to a non-null function pointer and would always yield true */ + if constexpr(offset_ptrs_by_tile_coords) + { + a_ptr += static_cast(i_m) * kargs.stride_A; 
aq_ptr += i_m; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + ds_ptr[i] = + static_cast(ds_ptr[i]) + + sizeof(DDataType_) * static_cast(i_m) * kargs.stride_Ds[i]; + }); + c_ptr += static_cast(i_m) * kargs.stride_C; + + kargs.M = std::min(kargs.M - i_m, TilePartitioner::MPerBlock); + i_m = 0; + } + // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; RunGemm(a_ptr, b_ptr, aq_ptr, bq_ptr, - kargs.ds_ptr, + ds_ptr, c_ptr, smem_ptr, kargs,