ck_tile kernel for gemm with groupwise quantized B tensor. (#2663)

* This change introduces new pipelines with Intrawave scheduler and block gemm primitives that loads the scale tensor to registers to perform dequantization post MFMA on C tensor in registers.

Scale tensor data, BQ is spliced across threads in registers and not stored in LDS.

Current support is for the following combinations, but it should be fairly straightforward to extend support to more formats.

fp8, fp8 -> f32
bf8, bf8 -> f32
fp8, i4 -> f32
bf8, i4 -> f32
Group size can go down to as low as K length of underlying WarpGemm primitive.

* Solve merge conflict

* [CK TILE] Update CHANGELOG.md

---------

Co-authored-by: Vijay Krishnamoorthy <vjkrish@fb.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Cong Ma <congma13@amd.com>
This commit is contained in:
Vijay Krish
2025-08-28 23:43:02 -07:00
committed by GitHub
parent 428090f749
commit 4208e28988
20 changed files with 2471 additions and 26 deletions

View File

@@ -12,7 +12,7 @@
namespace ck_tile {
template <typename Problem, index_t UnaryOpSize_ = 8>
struct BlockGemmQuantBase
struct BlockGemmAQuantBase
{
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
@@ -67,7 +67,7 @@ struct BlockGemmQuantBase
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
{
private:
template <typename PipelineProblem_, typename GemmPolicy_>
@@ -103,13 +103,13 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
using I1 = number<1>;
static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
"Error! WarpGemm's MWarp is not consisten with BlockGemmShape!");
"Error! WarpGemm's MWarp is not consistent with BlockGemmShape!");
static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
"Error! WarpGemm's NWarp is not consisten with BlockGemmShape!");
"Error! WarpGemm's NWarp is not consistent with BlockGemmShape!");
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
"Error! WarpGemm's M is not consisten with BlockGemmShape!");
"Error! WarpGemm's M is not consistent with BlockGemmShape!");
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
"Error! WarpGemm's N is not consisten with BlockGemmShape!");
"Error! WarpGemm's N is not consistent with BlockGemmShape!");
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
@@ -170,7 +170,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using Base = BlockGemmQuantBase<Problem_>;
using Base = BlockGemmAQuantBase<Problem_>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
@@ -323,7 +323,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
[[maybe_unused]] BSmemBlockWindow& b_block_window)
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"The CDataType as defined in traits should be the same as corresponding "
"C block tensor data type!");
constexpr auto warp_size = get_warp_size();

View File

@@ -0,0 +1,439 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
namespace ck_tile {
template <typename Problem, index_t UnaryOpSize_ = 8>
struct BlockGemmBQuantBase
{
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
static constexpr index_t UnaryOpSize = UnaryOpSize_;
template <typename T>
CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale)
{
float scale_reg_f = 0.f;
if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
{
scale_reg_f =
ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
}
else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
{
scale_reg_f =
ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
}
else if constexpr(std::is_same_v<BQDataType, float>)
{
scale_reg_f = ck_tile::bit_cast<float>(scale);
}
else
{
static_assert(false, "BQDataType must be float, fp8_t or bf8_t.");
}
return scale_reg_f;
}
// can be inherited from A
template <typename WarpWindow, typename WarpTile>
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
const WarpWindow& warp_window)
{
const element_wise::PassThroughPack8 elementwise_op{};
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
});
}
};
// A is block window on shared memory
// BQ (scale tensor) is block distributed tensor.
// Consecutive kQuantGroupSize elements of B are quantized with a separate scale.
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
{
private:
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
// Threadblock GEMM tile size
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t BQPerBlock = KPerBlock / kQuantGroupSize;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
// number of warps along M and N for threadblock's GEMM problem size
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
using I0 = number<0>;
using I1 = number<1>;
static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
"Error! WarpGemm's MWarp is not consistent with BlockGemmShape!");
static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
"Error! WarpGemm's NWarp is not consistent with BlockGemmShape!");
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
"Error! WarpGemm's M is not consistent with BlockGemmShape!");
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
"Error! WarpGemm's N is not consistent with BlockGemmShape!");
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t QScalesPerBlockRow =
(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
static constexpr index_t QScalesPerWarpGemmRow =
(WarpGemm::kK + kQuantGroupSize - 1) / kQuantGroupSize;
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static_assert(kQuantGroupSize % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of kQuantGroupSize");
static_assert(QScalesPerWarpGemmRow == 1,
"Error! kQuantGroupSize shouldn't be smaller than WarpGemm::kK");
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
static_assert(KPerBlock / kQuantGroupSize > 0,
"Error! Each row of blockgemm should have a separate scale");
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
"Error! Warps should cover all Block tile!");
static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock,
"Error! Warps should cover all Block tile!");
// Currently tested combinations (A, B, BQ)
// 1. fp8, fp8, fp32 -> f32
// 2. bf8, bf8, fp32 -> f32
// 3. i4, fp8, (fp8/fp32) -> f32
// 4. i4, bf8, (fp8/fp32) -> f32
static_assert((std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>) &&
(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
std::is_same_v<BDataType, ck_tile::pk_int4_t>) &&
(std::is_same_v<BQDataType, float> ||
std::is_same_v<BQDataType, ck_tile::fp8_t> ||
std::is_same_v<BQDataType, ck_tile::bf8_t>) &&
(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>) &&
std::is_same_v<CDataType, fp32_t>);
static constexpr index_t InterWaveSchedulingMacClusters = 1;
static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
};
public:
using Traits = GemmTraits_<Problem_, Policy_>;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using BQDataType = remove_cvref_t<typename Traits::BQDataType>;
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using Base = BlockGemmBQuantBase<Problem_>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
static constexpr auto Scheduler = Traits::Scheduler;
static constexpr uint8_t kA_cvt_scale = std::is_same_v<ADataType, pk_int4_t> ? 16 : 1;
static constexpr uint8_t kB_cvt_scale = std::is_same_v<BDataType, pk_int4_t> ? 16 : 1;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
static constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using I0 = number<0>;
using I1 = number<1>;
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KIterInterwave>,
sequence<KIterPerWarp>>;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, KIterSeq>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KIterInterwave>,
sequence<KIterPerWarp>>;
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, KIterSeq>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
private:
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
struct BlockGemmImpl
{
};
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
{
static constexpr auto ALdsTileDistr =
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Base::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else
{
load_tile(a_warp_tile_, a_block_window);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Base::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else
{
load_tile(b_warp_tile_, b_block_window);
}
}
// C += A * B
template <typename CBlockTensor,
typename BQBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
BQBlockTensor& bq_block_tensor,
[[maybe_unused]] ASmemBlockWindow& a_block_window,
[[maybe_unused]] BSmemBlockWindow& b_block_window)
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as corresponding "
"C block tensor data type!");
// hot loop:
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
CWarpTensor c_warp_tensor;
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() =
a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
if constexpr(kIterInQScale == 0)
{
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
});
// Need to multiply bquant with accumulated C
//
// The accumulated C tile has the standard distribution. For example
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
// [26,0], [27,0].
//
// These elements are in different rows, need to get the scale value
// for the corresponding row.
// Based on bquant's tile distribution, it can be inferred which
// lane holds the relevant scale. For example, the scales corresponding
// to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9,
// 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 respectively.
//
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
constexpr index_t reg_offset = nIter * Traits::BQPerBlock + kQScale;
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f *
kA_cvt_scale * kB_cvt_scale);
});
});
});
});
}
};
public:
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
}
// C += A * B
template <typename CBlockTensor,
typename BQBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
BQBlockTensor& bq_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
block_gemm_impl_(c_block_tensor, bq_block_tensor, a_block_window, b_block_window);
}
private:
BlockGemmImpl<Scheduler, Traits> block_gemm_impl_{};
};
} // namespace ck_tile

View File

@@ -0,0 +1,679 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
struct BQuantGemmProblem
{
CK_TILE_HOST BQuantGemmProblem() = default;
CK_TILE_HOST BQuantGemmProblem(index_t M_,
index_t N_,
index_t K_,
index_t QK_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_,
index_t stride_BQ_)
: M(M_),
N(N_),
K(K_),
QK(QK_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_C(stride_C_),
stride_BQ(stride_BQ_)
{
}
index_t M;
index_t N;
index_t K;
index_t QK;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t stride_BQ;
};
struct BQuantGemmHostArgs : public BQuantGemmProblem
{
CK_TILE_HOST BQuantGemmHostArgs() = default;
CK_TILE_HOST BQuantGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
const void* bq_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t QK_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_,
index_t stride_BQ_)
: BQuantGemmProblem(M_, N_, K_, QK_, stride_A_, stride_B_, stride_C_, stride_BQ_),
a_ptr(a_ptr_),
b_ptr(b_ptr_),
bq_ptr(bq_ptr_),
c_ptr(c_ptr_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
const void* bq_ptr;
void* c_ptr;
index_t k_batch;
};
struct BQuantGemmKernelArgs
{
const void* a_ptr;
const void* b_ptr;
const void* bq_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t QK;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t stride_BQ;
index_t k_batch;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct BQuantGemmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using BQLayout = remove_cvref_t<typename GemmPipeline::BQLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using BQDataType = remove_cvref_t<typename GemmPipeline::BQDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
// clang-format on
}
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr BQuantGemmKernelArgs
MakeKernelArgs(const BQuantGemmHostArgs& hostArgs)
{
return BQuantGemmKernelArgs{hostArgs.a_ptr,
hostArgs.b_ptr,
hostArgs.bq_ptr,
hostArgs.c_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.QK,
hostArgs.stride_A,
hostArgs.stride_B,
hostArgs.stride_C,
hostArgs.stride_BQ,
hostArgs.k_batch};
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(const BQuantGemmKernelArgs& kargs,
const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A);
}
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B);
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
}
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
{
splitted_k = __builtin_amdgcn_readfirstlane(KRead);
}
else
{
splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t splitted_k;
};
CK_TILE_HOST static bool IsSupportedArgument(const BQuantGemmKernelArgs& kargs)
{
if(kargs.k_batch != 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
}
return false;
}
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
if(kargs.QK % GemmPipeline::GetVectorSizeBQ() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
}
return false;
}
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
GemmPipeline::kPadK == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
"without padding!");
}
return false;
}
if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
}
return false;
}
}
else
{
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Can't support M that is not a multiple of MPerBlock without padding!");
}
return false;
}
if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
}
return false;
}
}
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Can't support N that is not a multiple of NPerBlock without padding!");
}
return false;
}
if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
}
return false;
}
}
else
{
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
GemmPipeline::kPadK == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
"without padding!");
}
return false;
}
if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
}
return false;
}
}
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Can't support N that is not a multiple of NPerBlock without padding!");
}
return false;
}
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
}
return false;
}
}
else
{
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Can't support M that is not a multiple of MPerBlock without padding!");
}
return false;
}
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
}
return false;
}
}
return true;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
const BQuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}();
const auto& bq_tensor_view = [&]() {
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(kargs.N, kargs.QK),
make_tuple(kargs.stride_BQ, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}();
const auto& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = splitk_batch_offset.splitted_k / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.N),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
else
{
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = splitk_batch_offset.splitted_k / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
}();
// TODO: enable vector write for C in ColMajor
const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
return make_tuple(a_tensor_view, bq_tensor_view, b_tensor_view, c_tensor_view);
}
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, GemmPipeline::kPadM>{});
}
}();
const auto& bq_pad_view = [&]() {
const auto& bq_tensor_view = views.at(I1);
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return pad_tensor_view(
bq_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
// TODO: Add support for padding.
sequence<false, false>{});
}();
const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I2);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
}();
// TODO vector write in for C in ColMajor
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I3);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
return make_tuple(a_pad_view, bq_pad_view, b_pad_view, c_pad_view);
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& bq_pad_view = views.at(I1);
const auto& b_pad_view = views.at(I2);
const auto& c_pad_view = views.at(I3);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();
const auto& bq_block_window = [&]() {
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
{i_n, 0});
}();
const auto& b_block_window = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
}
else
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
}
}();
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return make_tuple(a_block_window, bq_block_window, b_block_window, c_block_window);
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param bq_ptr input BQ pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
* @tparam DstInMemOp Destination memory operation (default: set).
*/
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const BQuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
a_ptr, b_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& bq_block_window = gemm_tile_windows.at(I1);
const auto& b_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
CK_TILE_DEVICE void operator()(BQuantGemmKernelArgs kargs) const
{
const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const SplitKBatchOffset splitk_batch_offset(kargs);
// options
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
assert(kargs.k_batch == 1);
RunGemm(a_ptr, b_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
};
} // namespace ck_tile

View File

@@ -28,7 +28,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
static_assert(std::is_same_v<AQLayout, ck_tile::tensor_layout::gemm::RowMajor>);
return GetAQGlobalVectorLoadSize<Problem, AQDataType, MPerBlock, KPerBlockAQ>();
return GetABQGlobalVectorLoadSize<Problem, AQDataType, MPerBlock, KPerBlockAQ>();
}
template <typename Problem>

View File

@@ -0,0 +1,53 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
namespace ck_tile {
template <typename Problem, typename Policy>
struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
{
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = typename Base::ADataType;
using ALayout = typename Base::ALayout;
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize;
static_assert(KPerBlock % QuantGroupSize == 0,
"KPerBlock must be a multiple of QuantGroupSize");
// Create DRAM tile window for BQ
template <typename BQDramBlockWindowTmp>
CK_TILE_DEVICE constexpr auto
GetBQDramLoadWindow(const BQDramBlockWindowTmp& bq_dram_block_window_tmp) const
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using YPerTile = number<NPerBlock>;
using XPerTile = number<KPerBlockBQ>;
auto bq_copy_dram_window =
make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(YPerTile(), XPerTile()),
bq_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBQDramTileDistribution<Problem>());
return bq_copy_dram_window;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,93 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "gemm_group_quant_utils.hpp"
namespace ck_tile {
struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy
{
using Base = UniversalGemmPipelineAgBgCrPolicy;
using Base::I0;
using Base::I1;
using Base::I2;
using Base::ATileAccessPattern;
using Base::BTileAccessPattern;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ()
{
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::kQuantGroupSize;
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlock, KPerBlockBQ>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution()
{
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BlockGemmShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::kQuantGroupSize;
constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using TileEncodingPattern = TileDistributionEncodingPatternBQ<BlockGemmShape,
WarpGemm,
BlockSize,
NPerBlock,
KPerBlockBQ,
VecLoadSize>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0,
"KPerWarpGemm must be a multiple of kQuantGroupSize!");
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
std::is_same_v<typename Problem::ComputeDataType, bf8_t>);
static_assert(std::is_same_v<typename Problem::CDataType, float>);
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BQuantBlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,475 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <typename Problem>
struct BaseBQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
if(has_hot_loop)
{
if(tail_number == ck_tile::TailNumber::Full)
{
return run_func(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_number == ck_tile::TailNumber::Odd)
{
return run_func(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_number == ck_tile::TailNumber::Even)
{
return run_func(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("Unsupported tail number for this operation !!!");
}
}
else
{
if(tail_number == ck_tile::TailNumber::Full)
{
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_number == ck_tile::TailNumber::Odd)
{
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_number == ck_tile::TailNumber::Even)
{
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("Unsupported tail number for this operation !!!");
}
}
}
};
template <typename Problem, typename Policy = GemmBQuantPipelineAgBgCrDefaultPolicy>
struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV3<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmBQuantPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
static constexpr index_t BQPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BQDataType>>::PackedSize;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetVectorSizeBQ()
{
return Policy::template GetVectorSizeBQ<Problem>();
}
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
using Base::PrefetchStages;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
return concat('_', "bquant_pipeline_AgBgCrCompV3",
concat('x', MPerBlock, NPerBlock, KPerBlock),
BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
concat('x', kPadM, kPadN, kPadK), "QuantGroupSize", QuantGroupSize);
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST static std::string Print()
{
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
constexpr index_t A_LDS_Read_Width = GetSmemPackA();
constexpr index_t B_LDS_Read_Width = GetSmemPackB();
constexpr index_t A_LDS_Write_Width = GetSmemPackA();
constexpr index_t B_LDS_Write_Width = GetSmemPackB();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t BQ_Buffer_Load_Inst_Num =
NPerBlock * KPerBlockBQ / (BlockSize * GetVectorSizeBQ());
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
auto str = std::stringstream{};
str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", "
<< "BQ vector size: " << GetVectorSizeBQ() << "\n"
<< "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
<< "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
<< ", " << "BQ buffer load inst: " << BQ_Buffer_Load_Inst_Num << "\n"
<< "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "QuantGroupSize: " << QuantGroupSize << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename BQDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>> &&
std::is_same_v<BQDataType,
remove_cvref_t<typename BQDramBlockWindowTmp::DataType>>,
"A/B/BQ Dram block window should have the same data type as appropriate "
"([A|B|BQ]DataType) defined in Problem definition!");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_bq_col_major =
std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
static_assert(NPerBlock == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
"Bq block window has incorrect lengths for defined BqLayout!");
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
using BQDramTileWindowStep = typename BQDramBlockWindowTmp::BottomTensorIndex;
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
auto bq_copy_dram_window = Base::GetBQDramLoadWindow(bq_dram_block_window_tmp);
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using BQBlockTileDistr = decltype(bq_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
using BQBlockTile =
decltype(make_static_distributed_tensor<BQDataType>(BQBlockTileDistr{}));
auto block_gemm = BlockGemm();
ABlockTile a_block_tile;
BBlockTile b_block_tile;
BQBlockTile bq_block_tile[2];
int currIdx = 0;
auto c_block_tile = block_gemm.MakeCBlockTile();
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BQDramTileWindowStep bq_dram_tile_window_step =
is_bq_col_major ? make_array(0, KPerBlockBQ) : make_array(KPerBlockBQ, 0);
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(
bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
__builtin_amdgcn_sched_barrier(0);
if constexpr(HasHotLoop)
{
index_t i = 0;
do
{
block_sync_lds();
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2],
bq_copy_dram_window,
bq_dram_tile_window_step);
block_gemm(
c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
currIdx = (currIdx + 1) % 2;
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
{
block_gemm(
c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
}
else
{
Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2],
bq_copy_dram_window,
bq_dram_tile_window_step);
block_gemm(
c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
currIdx = (currIdx + 1) % 2;
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(
c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
}
return c_block_tile;
}
};
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename BQDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
bq_dram_block_window_tmp,
num_loop,
p_smem);
}
};
} // namespace ck_tile

View File

@@ -8,7 +8,7 @@
namespace ck_tile {
template <typename Problem, typename DataType, index_t YPerTile, index_t XPerTile>
CK_TILE_HOST_DEVICE static constexpr auto GetAQGlobalVectorLoadSize()
CK_TILE_HOST_DEVICE static constexpr auto GetABQGlobalVectorLoadSize()
{
using I1 = number<1>;
constexpr index_t NWarps = Problem::BlockGemmShape::BlockWarps::at(I1{});
@@ -164,4 +164,56 @@ struct TileDistributionEncodingPatternAQTransposedC : public TileDistributionEnc
}
};
// TODO:: might need to update
template <typename BlockGemmShape,
typename WarpGemm,
index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize>
struct TileDistributionEncodingPatternBQ : public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
static constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarps * WarpGemm::kN);
static_assert(num_warps == MWarps * NWarps * KWarps);
// KWarps > 1 isn't supported
static_assert(KWarps == 1);
// # of elements per thread
static constexpr index_t X = XPerTile;
static constexpr index_t XR = 2;
// Number of iters per warp
// MIters are indexed using (Y0, Y1)
static constexpr index_t Y0 = NIterPerWarp;
// # of warps in Y dim
static constexpr index_t Y1 = NWarps;
static constexpr index_t Y2 = WarpGemm::kN;
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, XR>,
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
};
} // namespace ck_tile

View File

@@ -121,4 +121,107 @@ using GemmAQuantPipelineProblem = GemmAQuantPipelineProblemBase<ADataType_,
HasHotLoop_,
TailNum_>;
template <typename ADataType_,
typename BDataType_,
typename BQDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
uint32_t QuantGroupSize_,
typename ComputeDataType_ = ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
struct GemmBQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>
{
using Base = GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>;
using Traits = typename Base::Traits;
using typename Base::ADataType;
using typename Base::BDataType;
using typename Base::CDataType;
using typename Base::ComputeDataType;
using BQDataType = remove_cvref_t<BQDataType_>;
using BlockGemmShape = typename Base::BlockGemmShape;
using typename Base::ALayout;
using typename Base::BLayout;
using typename Base::CLayout;
static constexpr bool TransposeC = Traits::TransposeC;
using Base::kBlockSize;
using Base::kPadK;
using Base::kPadM;
using Base::kPadN;
using Base::DoubleSmemBuffer;
using Base::VectorLoadSize;
using BQLayout = remove_cvref_t<typename Traits::BQLayout>;
static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
static constexpr auto Scheduler = Scheduler_;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
static_assert(BlockGemmShape::kK % kQuantGroupSize == 0);
static_assert(Scheduler == GemmPipelineScheduler::Intrawave);
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm_bquant_problem",
concat('x', VectorLoadSize, kBlockSize),
concat('x', kPadM, kPadN, kPadK),
Scheduler,
"QuantGroupSize",
kQuantGroupSize);
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBQ()
{
return VectorLoadSize / sizeof(BQDataType);
}
static constexpr index_t VectorSizeBQ = []() { return kPadK ? 1 : GetAlignmentBQ(); }();
};
template <typename ADataType_,
typename BDataType_,
typename BQDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
uint32_t QuantGroupSize_,
typename ComputeDataType_ = ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
using GemmBQuantPipelineProblem = GemmBQuantPipelineProblemBase<ADataType_,
BDataType_,
BQDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
QuantGroupSize_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
} // namespace ck_tile

View File

@@ -28,6 +28,35 @@ struct TileGemmAQuantTraits
using CLayout = CLayout_;
using AQLayout = AQLayout_;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
};
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
bool PreshuffleQuant_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
typename BQLayout_ = BLayout_>
struct TileGemmBQuantTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
static constexpr int _VectorSize = 16;
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
using BQLayout = BQLayout_;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
static constexpr bool PreshuffleQuant = PreshuffleQuant_;