mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Support large tensors in the quant GEMM kernel
Implemented only for large M, with the RCR layout and RowCol quantization.
This commit is contained in:
@@ -1823,11 +1823,11 @@ struct QuantGemmMultiDKernel
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void Run_(const KernelArgs& kargs) const
|
||||
CK_TILE_DEVICE void Run_(KernelArgs kargs) const
|
||||
{
|
||||
const auto blockId = amd_wave_read_first_lane(blockIdx.x);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
|
||||
@@ -1844,13 +1844,42 @@ struct QuantGemmMultiDKernel
|
||||
static_cast<const BQDataType*>(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
std::array<const void*, NumDTensor> ds_ptr = kargs.ds_ptr;
|
||||
|
||||
// Large tensor support (when M is large, N and K are relatively small)
|
||||
constexpr bool offset_ptrs_by_tile_coords = [] {
|
||||
bool suitable = kQuantType == QuantType::RowColQuant;
|
||||
suitable = suitable && std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
suitable = suitable && std::is_same_v<tensor_layout::gemm::RowMajor, DiLayout>;
|
||||
});
|
||||
suitable = suitable && std::is_same_v<tensor_layout::gemm::RowMajor, CLayout>;
|
||||
return suitable;
|
||||
};
|
||||
if constexpr(offset_ptrs_by_tile_coords)
|
||||
{
|
||||
a_ptr += static_cast<std::ptrdiff_t>(i_m) * kargs.stride_A;
|
||||
aq_ptr += i_m;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
|
||||
ds_ptr[i] =
|
||||
static_cast<const char*>(ds_ptr[i]) +
|
||||
sizeof(DDataType_) * static_cast<std::ptrdiff_t>(i_m) * kargs.stride_Ds[i];
|
||||
});
|
||||
c_ptr += static_cast<std::ptrdiff_t>(i_m) * kargs.stride_C;
|
||||
|
||||
kargs.M = std::min(kargs.M - i_m, TilePartitioner::MPerBlock);
|
||||
i_m = 0;
|
||||
}
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
kargs.ds_ptr,
|
||||
ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
|
||||
Reference in New Issue
Block a user