Support large tensors in quant gemm kernel

Implemented only for large M, RCR layout with RowCol quantization.
This commit is contained in:
Anton Gorenko
2026-06-19 16:27:02 +05:00
parent ed8c9dd4f2
commit 016da0d5f0

View File

@@ -1823,11 +1823,11 @@ struct QuantGemmMultiDKernel
}
}
CK_TILE_DEVICE void Run_(const KernelArgs& kargs) const
CK_TILE_DEVICE void Run_(KernelArgs kargs) const
{
const auto blockId = amd_wave_read_first_lane(blockIdx.x);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
const SplitKBatchOffset splitk_batch_offset(kargs);
@@ -1844,13 +1844,42 @@ struct QuantGemmMultiDKernel
static_cast<const BQDataType*>(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset;
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
std::array<const void*, NumDTensor> ds_ptr = kargs.ds_ptr;
// Large tensor support (when M is large, N and K are relatively small)
constexpr bool offset_ptrs_by_tile_coords = [] {
bool suitable = kQuantType == QuantType::RowColQuant;
suitable = suitable && std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
suitable = suitable && std::is_same_v<tensor_layout::gemm::RowMajor, DiLayout>;
});
suitable = suitable && std::is_same_v<tensor_layout::gemm::RowMajor, CLayout>;
return suitable;
};
if constexpr(offset_ptrs_by_tile_coords)
{
a_ptr += static_cast<std::ptrdiff_t>(i_m) * kargs.stride_A;
aq_ptr += i_m;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
ds_ptr[i] =
static_cast<const char*>(ds_ptr[i]) +
sizeof(DDataType_) * static_cast<std::ptrdiff_t>(i_m) * kargs.stride_Ds[i];
});
c_ptr += static_cast<std::ptrdiff_t>(i_m) * kargs.stride_C;
kargs.M = std::min(kargs.M - i_m, TilePartitioner::MPerBlock);
i_m = 0;
}
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
RunGemm(a_ptr,
b_ptr,
aq_ptr,
bq_ptr,
kargs.ds_ptr,
ds_ptr,
c_ptr,
smem_ptr,
kargs,