diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 9202da8bd3..5f013b5a94 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -107,32 +107,16 @@ struct is_preshuffleB_enabled> }; } // namespace detail -struct QuantGemmProblem +template +struct QuantGemmMultiDHostArgs { - CK_TILE_HOST QuantGemmProblem() = default; - CK_TILE_HOST QuantGemmProblem(index_t M_, - index_t N_, - index_t K_, - index_t QK_A_, - index_t QK_B_, - index_t stride_A_, - index_t stride_B_, - index_t stride_C_, - index_t stride_AQ_, - index_t stride_BQ_) - : M(M_), - N(N_), - K(K_), - QK_A(QK_A_), - QK_B(QK_B_), - stride_A(stride_A_), - stride_B(stride_B_), - stride_C(stride_C_), - stride_AQ(stride_AQ_), - stride_BQ(stride_BQ_) - { - } - + const void* a_ptr; + const void* b_ptr; + const std::array ds_ptr; + void* c_ptr; + const void* aq_ptr; + const void* bq_ptr; + index_t k_batch; index_t M; index_t N; index_t K; @@ -140,56 +124,20 @@ struct QuantGemmProblem index_t QK_B; index_t stride_A; index_t stride_B; + const std::array stride_Ds; index_t stride_C; index_t stride_AQ; index_t stride_BQ; }; -struct QuantGemmHostArgs : public QuantGemmProblem -{ - CK_TILE_HOST QuantGemmHostArgs() = default; - CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_, - const void* b_ptr_, - void* c_ptr_, - const void* aq_ptr_, - const void* bq_ptr_, - index_t k_batch_, - index_t M_, - index_t N_, - index_t K_, - index_t QK_A_, - index_t QK_B_, - index_t stride_A_, - index_t stride_B_, - index_t stride_C_, - index_t stride_AQ_, - index_t stride_BQ_) - : QuantGemmProblem( - M_, N_, K_, QK_A_, QK_B_, stride_A_, stride_B_, stride_C_, stride_AQ_, stride_BQ_), - a_ptr(a_ptr_), - b_ptr(b_ptr_), - aq_ptr(aq_ptr_), - bq_ptr(bq_ptr_), - c_ptr(c_ptr_), - k_batch(k_batch_) - { - } - - const void* a_ptr = nullptr; - const void* b_ptr = nullptr; - const void* aq_ptr = nullptr; - const void* bq_ptr = nullptr; - void* c_ptr = nullptr; - // k_batch must be a positive integer; defaults to 1 (no split-K). - index_t k_batch = 1; -}; - -struct QuantGemmKernelArgs +template +struct QuantGemmMultiDKernelArgs { const void* a_ptr; const void* b_ptr; const void* aq_ptr; const void* bq_ptr; + const std::array ds_ptr; void* c_ptr; index_t M; index_t N; @@ -198,6 +146,7 @@ struct QuantGemmKernelArgs index_t QK_B; index_t stride_A; index_t stride_B; + const std::array stride_Ds; index_t stride_C; index_t stride_AQ; index_t stride_BQ; @@ -233,7 +182,7 @@ template -struct QuantGemmKernel +struct QuantGemmMultiDKernel { using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; @@ -241,6 +190,7 @@ struct QuantGemmKernel using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; using CLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; using AQLayout = remove_cvref_t< typename detail::get_aq_layout_or::type>; @@ -257,6 +207,7 @@ struct QuantGemmKernel using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; using AccDataType = remove_cvref_t; using AQDataType = @@ -264,15 +215,23 @@ struct QuantGemmKernel using BQDataType = remove_cvref_t::type>; - static constexpr auto I0 = number<0>(); // A Tensor - static constexpr auto I1 = number<1>(); // AQ Tensor - static constexpr auto I2 = number<2>(); // B Tensor - static constexpr auto I3 = number<3>(); // BQ Tensor - static constexpr auto I4 = number<4>(); // C Tensor + static_assert(is_detected::value && + is_detected::value && + DsLayout::size() == DsDataType::size(), + "DsLayout and DsDataType must be tuples and must have the same size."); + + static constexpr index_t NumDTensor = DsDataType::size(); + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); static constexpr auto kQuantType = QuantType_; static constexpr bool RuntimeSplitKTail = RuntimeSplitKTail_; + using HostArgs = QuantGemmMultiDHostArgs; + using KernelArgs = QuantGemmMultiDKernelArgs; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off @@ -290,25 +249,26 @@ struct QuantGemmKernel return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize); } - CK_TILE_HOST static constexpr QuantGemmKernelArgs - MakeKernelArgs(const QuantGemmHostArgs& hostArgs) + CK_TILE_HOST static constexpr KernelArgs MakeKernelArgs(const HostArgs& hostArgs) { - return QuantGemmKernelArgs{hostArgs.a_ptr, - hostArgs.b_ptr, - hostArgs.aq_ptr, - hostArgs.bq_ptr, - hostArgs.c_ptr, - hostArgs.M, - hostArgs.N, - hostArgs.K, - hostArgs.QK_A, - hostArgs.QK_B, - hostArgs.stride_A, - hostArgs.stride_B, - hostArgs.stride_C, - hostArgs.stride_AQ, - hostArgs.stride_BQ, - hostArgs.k_batch}; + return KernelArgs{hostArgs.a_ptr, + hostArgs.b_ptr, + hostArgs.aq_ptr, + hostArgs.bq_ptr, + hostArgs.ds_ptr, + hostArgs.c_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.QK_A, + hostArgs.QK_B, + hostArgs.stride_A, + hostArgs.stride_B, + hostArgs.stride_Ds, + hostArgs.stride_C, + hostArgs.stride_AQ, + hostArgs.stride_BQ, + hostArgs.k_batch}; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -408,8 +368,8 @@ struct QuantGemmKernel public: struct SplitKBatchOffset { - __device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs, - const std::size_t k_id = blockIdx.z) + CK_TILE_DEVICE SplitKBatchOffset(const KernelArgs& kargs, + const std::size_t k_id = blockIdx.z) { constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2); // smallest unit of K work per block @@ -543,7 +503,7 @@ struct QuantGemmKernel }; CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr, - const QuantGemmKernelArgs& kargs, + const KernelArgs& kargs, const index_t k_size, const index_t i_m) { @@ -609,7 +569,7 @@ struct QuantGemmKernel } CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr, - const QuantGemmKernelArgs& kargs, + const KernelArgs& kargs, const index_t i_m, const index_t i_n, const index_t aq_group_offset = 0) @@ -796,7 +756,7 @@ struct QuantGemmKernel } CK_TILE_DEVICE static auto MakeBBlockWindow(const BDataType* b_ptr, - const QuantGemmKernelArgs& kargs, + const KernelArgs& kargs, const index_t k_size, const index_t i_n) { @@ -935,7 +895,7 @@ struct QuantGemmKernel } CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr, - const QuantGemmKernelArgs& kargs, + const KernelArgs& kargs, const index_t bq_group_offset, const index_t i_m, const index_t i_n) @@ -1125,9 +1085,101 @@ struct QuantGemmKernel return bq_block_window; } + template + CK_TILE_DEVICE static auto + MakeDTensorDescriptor(const index_t M, const index_t N, const index_t stride) + { + if constexpr(std::is_same_v) + { + return make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(stride, 1), number{}, number<1>{}); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, M), make_tuple(stride, 1), number{}, number<1>{}); + } + } + + template + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const DsTensorDesc& ds_desc, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor views + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + using DDataType_ = remove_cvref_t>; + return make_tensor_view( + static_cast(ds_ptr[i]), ds_desc[i]); + }, + number{}); + + // Step 2: Create padded views + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // Step 3: Create tile windows + const auto& ds_block_window = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, i_n}); + } + else + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_n, i_m}); + } + }, + number{}); + + return ds_block_window; + } + + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const KernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + const auto& ds_tensor_desc = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + return MakeDTensorDescriptor( + kargs.M, kargs.N, kargs.stride_Ds[i]); + }, + number{}); + + return MakeDBlockWindows(ds_ptr, ds_tensor_desc, i_m, i_n); + } + template CK_TILE_DEVICE static auto MakeCBlockWindow(CDataType* c_ptr, - const QuantGemmKernelArgs& kargs, + const KernelArgs& kargs, const index_t i_m, const index_t i_n) { @@ -1180,7 +1232,7 @@ struct QuantGemmKernel return c_block_window; } - CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs) + CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs) { // k_batch must be a positive integer. if(kargs.k_batch <= 0) @@ -1438,6 +1490,59 @@ struct QuantGemmKernel } } + bool DTensorIsValid = {true}; + static_for<0, NumDTensor, 1>{}([&](auto index) { + using DiLayout = remove_cvref_t>; + if(std::is_same_v == false) + { + DTensorIsValid = false; + } + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of " + "NPerBlock without padding!"); + } + DTensorIsValid = false; + } + if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!"); + } + DTensorIsValid = false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of " + "MPerBlock without padding!"); + } + DTensorIsValid = false; + } + if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!"); + } + DTensorIsValid = false; + } + } + }); + if(!DTensorIsValid) + { + return false; + } + if constexpr(std::is_same_v) { // For RowMajor C, M is the row dimension - check M alignment here because @@ -1585,9 +1690,10 @@ struct QuantGemmKernel const BDataType* b_ptr, const AQDataType* aq_ptr, const BQDataType* bq_ptr, + const std::array& ds_ptr, CDataType* c_ptr, void* smem_ptr, - const QuantGemmKernelArgs& kargs, + const KernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) @@ -1605,6 +1711,7 @@ struct QuantGemmKernel // the remaining K-groups from the split-K offset position. const auto& bq_block_window = MakeBQBlockWindow( bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); @@ -1667,13 +1774,13 @@ struct QuantGemmKernel kQuantType == QuantType::AQuantGrouped || kQuantType == QuantType::BQuantGrouped) { - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr); } else if constexpr(kQuantType == QuantType::RowColQuant) { EpiloguePipeline{}(c_block_window, c_block_tile, - c_block_window, + ds_block_window, smem_ptr, aq_block_window, bq_block_window); @@ -1683,7 +1790,7 @@ struct QuantGemmKernel const AccDataType aq_scale = type_convert(*aq_ptr); const AccDataType bq_scale = type_convert(*bq_ptr); EpiloguePipeline{}( - c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale); + c_block_window, c_block_tile, ds_block_window, smem_ptr, aq_scale, bq_scale); } } else @@ -1695,13 +1802,13 @@ struct QuantGemmKernel kQuantType == QuantType::AQuantGrouped || kQuantType == QuantType::BQuantGrouped) { - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr); } else if constexpr(kQuantType == QuantType::RowColQuant) { EpiloguePipeline{}(c_block_window, c_block_tile, - c_block_window, + ds_block_window, smem_ptr, aq_block_window, bq_block_window); @@ -1711,12 +1818,12 @@ struct QuantGemmKernel const AccDataType aq_scale = type_convert(*aq_ptr); const AccDataType bq_scale = type_convert(*bq_ptr); EpiloguePipeline{}( - c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale); + c_block_window, c_block_tile, ds_block_window, smem_ptr, aq_scale, bq_scale); } } } - CK_TILE_DEVICE void Run_(const QuantGemmKernelArgs& kargs) const + CK_TILE_DEVICE void Run_(const KernelArgs& kargs) const { const auto blockId = amd_wave_read_first_lane(blockIdx.x); const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); @@ -1739,8 +1846,17 @@ struct QuantGemmKernel // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; - RunGemm( - a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + RunGemm(a_ptr, + b_ptr, + aq_ptr, + bq_ptr, + kargs.ds_ptr, + c_ptr, + smem_ptr, + kargs, + splitk_batch_offset, + i_m, + i_n); } template @@ -1749,7 +1865,7 @@ struct QuantGemmKernel static constexpr bool kIsAvailableV> = T::kIsAvailable; - CK_TILE_DEVICE void operator()(const QuantGemmKernelArgs& kargs) const + CK_TILE_DEVICE void operator()(const KernelArgs& kargs) const { if constexpr(!kIsAvailableV) ignore = kargs; @@ -1758,6 +1874,99 @@ struct QuantGemmKernel } }; +struct QuantGemmHostArgs : public QuantGemmMultiDHostArgs<0> +{ + CK_TILE_HOST QuantGemmHostArgs() = default; + CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_, + const void* b_ptr_, + void* c_ptr_, + const void* aq_ptr_, + const void* bq_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t QK_A_, + index_t QK_B_, + index_t stride_A_, + index_t stride_B_, + index_t stride_C_, + index_t stride_AQ_, + index_t stride_BQ_) + : QuantGemmMultiDHostArgs{a_ptr_, + b_ptr_, + std::array{}, + c_ptr_, + aq_ptr_, + bq_ptr_, + k_batch_, + M_, + N_, + K_, + QK_A_, + QK_B_, + stride_A_, + stride_B_, + std::array{}, + stride_C_, + stride_AQ_, + stride_BQ_} + { + } +}; + +struct QuantGemmKernelArgs : public QuantGemmMultiDKernelArgs<0> +{ + CK_TILE_HOST QuantGemmKernelArgs() = default; + CK_TILE_HOST QuantGemmKernelArgs(const void* a_ptr_, + const void* b_ptr_, + const void* aq_ptr_, + const void* bq_ptr_, + void* c_ptr_, + index_t M_, + index_t N_, + index_t K_, + index_t QK_A_, + index_t QK_B_, + index_t stride_A_, + index_t stride_B_, + index_t stride_C_, + index_t stride_AQ_, + index_t stride_BQ_, + index_t k_batch_) + : QuantGemmMultiDKernelArgs<0>{a_ptr_, + b_ptr_, + aq_ptr_, + bq_ptr_, + std::array{}, + c_ptr_, + M_, + N_, + K_, + QK_A_, + QK_B_, + stride_A_, + stride_B_, + std::array{}, + stride_C_, + stride_AQ_, + stride_BQ_, + k_batch_} + { + } +}; + +template +using QuantGemmKernel = QuantGemmMultiDKernel; + } // namespace ck_tile #if __clang_major__ >= 23 #pragma clang diagnostic pop diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index be9d94d8b8..434be513b8 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -366,6 +366,7 @@ struct QuantGroupedGemmKernel {b_ptr}, aq_ptr, bq_ptr, + {}, c_ptr, smem_ptr, kargs,