Merge commit 'c363a98d4154c647c1a2d5331ad0d76879b84dfa' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-08 21:13:22 +00:00
parent 564276eff9
commit 375e499d10
22 changed files with 1307 additions and 395 deletions

View File

@@ -9,14 +9,190 @@
#include <string>
#include <tuple>
#include <memory>
#include <type_traits>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm_quant.hpp"
#include "ck_tile/host.hpp"
#include "quant_grouped_gemm.hpp"
template <typename GemmConfig,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout,
typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename QuantGroupSize,
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
{
constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
constexpr ck_tile::index_t TilePartitionerM01 = 4;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TilePartitionerGroupNum, TilePartitionerM01>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
ALayout,
BLayout,
CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false, // PreshuffleQuant
GemmConfig::PreshuffleB,
ALayout,
BLayout,
CLayout,
QuantMode,
AQLayout,
BQLayout,
GemmConfig::TransposeC,
GemmConfig::DoubleSmemBuffer,
GemmConfig::Persistent>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline =
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
GemmConfig::PreshuffleB>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
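// Editorial worked example (hedged; values assumed, not from this change):
// with K = 1000, k_batch = 2 and K_Tile = 128:
//   k_grain = 2 * 128 = 256
//   K_split = ((1000 + 256 - 1) / 256) * 128 = 4 * 128 = 512
// so each split-K batch covers a K-extent of 512, from which num_loop,
// has_hot_loop and the tail configuration are derived.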
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::BQuantGrouped;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
GemmConfig::TransposeC,
BDataType,
scheduler,
has_hot_loop_v,
tail_number_v>,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
ADataType,
scheduler,
has_hot_loop_v,
tail_number_v>>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
GemmConfig::TransposeC,
BDataType,
scheduler,
has_hot_loop_v,
tail_number_v>>;
using GemmPipeline =
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
QuantGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
GemmUniversalTraits::kQuantType>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
};
return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
template <typename GemmConfig,
typename ALayout,
typename AQLayout,
@@ -59,41 +235,48 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
BQLayout,
GemmConfig::TransposeC,
GemmConfig::DoubleSmemBuffer,
true>; // Persistence
GemmConfig::Persistent>;
float ave_time{0};
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
constexpr bool transpose_c = false;
using QuantGemmProblem = typename std::conditional<
QuantMode == ck_tile::QuantType::BQuantGrouped,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize>,
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::BQuantGrouped;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
GemmConfig::TransposeC>,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize>>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
transpose_c,
GemmConfig::TransposeC,
BDataType,
scheduler>>::type;
scheduler>>;
using GemmPipeline = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<GemmConfig::PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>;
using GemmPipeline =
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
@@ -146,6 +329,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
int main(int argc, char* argv[])
{
int result1 = !run_grouped_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);
int result1 = run_grouped_gemm_example(argc, argv);
return result1;
}

View File

@@ -64,6 +64,7 @@ struct GemmTypeConfig<ck_tile::bf8_t>
using CDataType = ck_tile::half_t;
};
template <bool Persistent_>
struct GemmConfigBase
{
static constexpr bool kPadM = false;
@@ -83,10 +84,11 @@ struct GemmConfigBase
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool PreshuffleB = false;
static constexpr bool Persistent = Persistent_;
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
template <typename PrecType, bool Persistent>
struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent>
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -101,8 +103,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType>
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
template <typename PrecType, bool Persistent>
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase<Persistent>
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -121,6 +123,66 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
static constexpr bool DoubleSmemBuffer = true;
};
template <ck_tile::QuantType QuantMode>
struct GemmQuantConfig;
template <>
struct GemmQuantConfig<ck_tile::QuantType::TensorQuant>
{
template <typename PrecType, bool Persistent>
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
template <typename GemmProblem, bool PreshuffleB = false>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmProblem>;
template <typename GemmProblem, bool PreshuffleB = false>
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
};
template <>
struct GemmQuantConfig<ck_tile::QuantType::RowColQuant>
{
template <typename PrecType, bool Persistent>
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
template <typename GemmProblem, bool PreshuffleB = false>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmProblem>;
template <typename GemmProblem, bool PreshuffleB = false>
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
};
template <>
struct GemmQuantConfig<ck_tile::QuantType::AQuantGrouped>
{
template <typename PrecType, bool Persistent>
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
template <typename GemmProblem, bool PreshuffleB = false>
using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<GemmProblem>;
template <typename GemmProblem, bool PreshuffleB = false>
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
};
template <>
struct GemmQuantConfig<ck_tile::QuantType::BQuantGrouped>
{
template <typename PrecType, bool Persistent>
using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill<PrecType, Persistent>;
template <typename GemmProblem, bool PreshuffleB = false>
using GemmPipeline = std::conditional_t<PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<GemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<GemmProblem>>;
template <typename GemmProblem, bool PreshuffleB = false>
using BaseGemmPipeline =
std::conditional_t<PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>>;
};
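// Editorial usage sketch (hedged; the problem type is an assumption): a quant
// mode maps to its config and pipeline through this trait, e.g.
//
//   using Cfg      = GemmQuantConfig<ck_tile::QuantType::BQuantGrouped>;
//   using Config   = Cfg::GemmConfig<ck_tile::fp8_t, /*Persistent=*/true>;
//   using Pipeline = Cfg::GemmPipeline<MyQuantGemmProblem, /*PreshuffleB=*/true>;
//
// With PreshuffleB = true this selects WPQuantBPipelineAgBgCrV2; otherwise
// BQuantGemmPipelineAgBgCrCompV3, as defined above.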
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
auto create_args(int argc, char* argv[])
@@ -148,8 +210,9 @@ auto create_args(int argc, char* argv[])
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "8", "group count.")
.insert("kbatch", "1", "kbatch for SplitK")
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol")
.insert("init", "0", "0. Random, 2. One(s) (Constant)");
.insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor, or rowcol")
.insert("init", "0", "0. Random, 2. One(s) (Constant)")
.insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -57,56 +57,83 @@ float invoke_gemm(int n_warmup,
float ave_time = 0;
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
// the gemm problems known on the host. Instead, we can just pass the pointer
// to the kernel and let the workgroups figure out which tiles to work on.
// This is useful when the gemm problems are generated dynamically.
// In this example, however, we generate the `kargs` using the known gemm_descs,
// and copy the gemm descriptions to the device memory.
// The contents of the memory pointed to by `kargs_ptr` could be
// written by, e.g., another kernel from an earlier stage.
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
assert(args[0].k_batch == 1);
for(const auto& arg : args)
if constexpr(!GemmConfig::Persistent)
{
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
arg.b_ptr,
arg.aq_ptr,
arg.bq_ptr,
arg.e_ptr,
arg.M,
arg.N,
arg.K,
arg.QK_A,
arg.QK_B,
arg.stride_A,
arg.stride_B,
arg.stride_E,
arg.stride_AQ,
arg.stride_BQ,
arg.k_batch});
ave_time =
grouped_gemm<GemmConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
QuantGroupSize,
QuantMode>(args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
gemm_workspace.GetDeviceBuffer());
}
else
{
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
// the gemm problems known on the host. Instead, we can just pass the pointer
// to the kernel and let the workgroups figure out which tiles to work on.
// This is useful when the gemm problems are generated dynamically.
// In this example, however, we generate the `kargs` using the known gemm_descs,
// and copy the gemm descriptions to the device memory.
// The contents of the memory pointed to by `kargs_ptr` could be
// written by, e.g., another kernel from an earlier stage.
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
if(args[0].k_batch != 1)
{
throw std::runtime_error("Split-K not supported yet for persistent kernel");
}
for(const auto& arg : args)
{
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
arg.b_ptr,
arg.aq_ptr,
arg.bq_ptr,
arg.e_ptr,
arg.M,
arg.N,
arg.K,
arg.QK_A,
arg.QK_B,
arg.stride_A,
arg.stride_B,
arg.stride_E,
arg.stride_AQ,
arg.stride_BQ,
arg.k_batch});
}
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time = grouped_gemm_tileloop<GemmConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
QuantGroupSize,
QuantMode>(stream, group_count, kargs_ptr);
}
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time = grouped_gemm_tileloop<GemmConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
QuantGroupSize,
QuantMode>(stream, group_count, kargs_ptr);
std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")";
@@ -259,13 +286,24 @@ int run_grouped_gemm_example_with_layouts(int argc,
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
}
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize
BQK = 0; // No B quantization
if(K % QuantGroupSize::kK != 0)
{
throw std::runtime_error(
"K must be divisible by QuantGroupSize::kK for AQuantGrouped mode");
}
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
AQK = 0; // No A quantization
BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize
if(K % QuantGroupSize::kK != 0)
{
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
throw std::runtime_error(
"K must be divisible by QuantGroupSize::kK for BQuantGrouped mode");
}
}
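// Editorial worked example (hedged; values assumed): with K = 4096 and
// QuantGroupSize::kK = 128, AQuantGrouped yields AQK = 4096 / 128 = 32 scale
// groups along K (BQK = 0), and BQuantGrouped yields the mirrored BQK = 32
// (AQK = 0); any K that is not a whole number of quant groups is rejected
// by the divisibility checks above.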
@@ -284,6 +322,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
}
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
stride_AQs[i] =
ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout));
stride_BQs[i] = 0; // No B quantization
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
stride_AQs[i] = 0; // No A quantization
@@ -311,10 +355,17 @@ int run_grouped_gemm_example_with_layouts(int argc,
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout))));
}
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout))));
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
ck_tile::host_tensor_descriptor(0, 0, stride_BQs[i], is_row_major(bq_layout))));
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout))));
ck_tile::host_tensor_descriptor(0, 0, stride_AQs[i], is_row_major(aq_layout))));
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout))));
}
@@ -444,7 +495,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
bq_tensors[i],
c_m_n_host_ref);
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
@@ -452,6 +503,17 @@ int run_grouped_gemm_example_with_layouts(int argc,
AccDataType,
CDataType,
QuantGroupSize,
true>(
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
ck_tile::reference_gemm_quant<ADataType,
BQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
}
@@ -477,7 +539,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
return pass;
}
template <typename GemmConfig, typename PrecType, ck_tile::QuantType QuantMode>
template <typename PrecType, ck_tile::QuantType QuantMode, typename GemmConfig>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
@@ -494,6 +556,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
@@ -511,7 +574,24 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
}
}
template <template <typename PrecType> typename GemmConfig>
template <typename PrecType, ck_tile::QuantType QuantMode>
int run_gemm_example_persistency(
std::string a_layout, std::string b_layout, bool persistent, int argc, char* argv[])
{
if(persistent)
{
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType, true>;
return run_gemm_example_prec_type<PrecType, QuantMode, GemmConfig>(
a_layout, b_layout, argc, argv);
}
else
{
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType, false>;
return run_gemm_example_prec_type<PrecType, QuantMode, GemmConfig>(
a_layout, b_layout, argc, argv);
}
}
int run_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -524,29 +604,29 @@ int run_grouped_gemm_example(int argc, char* argv[])
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
std::string quant_mode = arg_parser.get_str("quant_mode");
bool persistent = arg_parser.get_bool("persistent");
if(data_type == "fp8")
{
if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "aquant")
{
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, persistent, argc, argv);
}
else
{
@@ -557,24 +637,23 @@ int run_grouped_gemm_example(int argc, char* argv[])
{
if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "aquant")
{
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, persistent, argc, argv);
}
else
{

View File

@@ -61,6 +61,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
@@ -154,6 +155,10 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using Base = BlockGemmBQuantBase<Problem_>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
@@ -271,12 +276,20 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_>(a_warp_tile_, a_block_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_>(b_warp_tile_, b_block_window);
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
// If the B datatype is pk_int4, it is converted before being stored in LDS
load_int4_tile<OverrideBDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
}
// C += A * B
@@ -397,11 +410,16 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
MakeCBlockTile();
}
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
}
// C += A * B

View File

@@ -426,7 +426,6 @@ struct QuantGemmKernel
if constexpr(kQuantType == QuantType::BQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
@@ -781,7 +780,9 @@ struct QuantGemmKernel
{
if constexpr(PreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
"PreshuffleQuant with BQuantGrouped currently only supports "
"ColumnMajor BQ layout");
return MakePreshuffledQuantTensorView<
GemmPipeline::KPerBlockBQ,
@@ -791,14 +792,35 @@ struct QuantGemmKernel
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
make_tuple(kargs.stride_BQ, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
// For RowMajor BQ: memory layout is [K/QuantGroupK][N/QuantGroupN]
// Dimensions: [K/QuantGroupK, N/QuantGroupN]
// Strides: [N/QuantGroupN, 1]
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK),
integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
// For ColumnMajor BQ: memory layout is [N/QuantGroupN][K/QuantGroupK]
// Dimensions: [N/QuantGroupN, K/QuantGroupK]
// Strides: [K/QuantGroupK, 1]
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN),
integer_divide_ceil(kargs.K, QuantGroupSize::kK)),
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
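// Editorial worked example (hedged; N, K and group sizes assumed): with
// N = 512, K = 4096 and QuantGroupSize kN = kK = 128, the RowMajor view
// above has dims [K/kK, N/kN] = [32, 4] with strides [4, 1], while the
// ColumnMajor view has dims [N/kN, K/kK] = [4, 32] with strides [32, 1].
// Both index the same 128 scale values; only the fastest-varying dimension
// differs.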
}
}
else
@@ -1023,10 +1045,10 @@ struct QuantGemmKernel
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
if constexpr(PreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN;
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
@@ -1042,13 +1064,23 @@ struct QuantGemmKernel
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
{0, i_n / QuantGroupSize::kN});
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
}
}
}
else

View File

@@ -163,7 +163,6 @@ struct QuantGroupedGemmKernel
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;
static_assert(UsePersistentKernel == true, "UsePersistentKernel must be true");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -262,10 +261,9 @@ struct QuantGroupedGemmKernel
auto karg =
QuantGroupedGemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
type_convert<CDataType*>(gemm_descs[i].e_ptr),
type_convert<const AQDataType*>(gemm_descs[i].aq_ptr),
type_convert<const BQDataType*>(gemm_descs[i].bq_ptr),
gemm_descs[i].k_batch,
type_convert<CDataType*>(gemm_descs[i].e_ptr),
M,
N,
K,
@@ -275,7 +273,8 @@ struct QuantGroupedGemmKernel
stride_b,
stride_e,
gemm_descs[i].stride_AQ,
gemm_descs[i].stride_BQ};
gemm_descs[i].stride_BQ,
gemm_descs[i].k_batch};
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
}
@@ -342,16 +341,32 @@ struct QuantGroupedGemmKernel
else
{
RunGemmWithPipelineSelection(a_ptr,
b_ptr,
aq_ptr,
bq_ptr,
c_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
if constexpr(UsePersistentKernel)
{
RunGemmWithPipelineSelection(a_ptr,
b_ptr,
aq_ptr,
bq_ptr,
c_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else // Non-persistent kernel
{
Base::RunGemm({a_ptr},
{b_ptr},
aq_ptr,
bq_ptr,
c_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
}
@@ -451,7 +466,24 @@ struct QuantGroupedGemmKernel
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
if constexpr(kQuantType == QuantType::BQuantGrouped)
if constexpr(kQuantType == QuantType::AQuantGrouped)
{
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
b_block_window,
aq_block_window,
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0);
auto& c_block_window = gemm_tile_windows.at(Base::I4);
// Run Epilogue Pipeline
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
// Run GEMM pipeline
@@ -496,6 +528,53 @@ struct QuantGroupedGemmKernel
}
}
CK_TILE_DEVICE index_t FindGroupId(const QuantGemmTransKernelArg* gemm_desc_ptr,
index_t block_id,
index_t group_count) const
{
// Binary search over the contiguous [block_start, block_end) block ranges
// of the groups to find the group that owns this workgroup's block_id.
index_t left = 0;
index_t right = group_count;
index_t group_id = index_t((left + right) >> 1);
while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
block_id < gemm_desc_ptr[group_id].block_end)) &&
left <= right)
{
if(block_id < gemm_desc_ptr[group_id].block_start)
{
// block_id belongs to an earlier group: shrink the window from the right.
right = group_id;
}
else
{
// block_id belongs to a later group: shrink the window from the left.
left = group_id;
}
group_id = index_t((left + right) >> 1);
}
return group_id;
}
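// Editorial illustration (hedged; ranges assumed): with three groups whose
// [block_start, block_end) ranges are [0, 8), [8, 20) and [20, 26), a
// workgroup with block_id = 12 resolves to group_id = 1; its tile within
// that group is then derived from block_id - block_start = 4, as in the
// non-persistent operator() below.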
// For non-persistent kernels
template <bool U = UsePersistentKernel, typename = std::enable_if_t<!U>>
CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
index_t group_count) const
{
const index_t block_id = ck_tile::get_block_1d_id();
const auto gemm_desc_ptr = reinterpret_cast<const QuantGemmTransKernelArg*>(
cast_pointer_to_generic_address_space(gemm_descs_const));
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
const auto& kargs = gemm_desc_ptr[group_id];
const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
0,
kargs.group_karg.M,
kargs.group_karg.N,
(block_id - kargs.block_start) % grid_size_2d);
Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d);
}
// For persistent kernels
template <bool U = UsePersistentKernel,
typename = std::enable_if_t<U>,

View File

@@ -319,6 +319,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
if constexpr(HasHotLoop)
{
constexpr index_t tail_count =
((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) ? 1 : 2;
index_t i = 0;
do
{
@@ -366,7 +368,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 1));
} while(i < (num_loop - tail_count));
}
// tail
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
@@ -439,6 +441,51 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
num_loop,
p_smem);
}
/// @brief Runtime pipeline dispatch operator for grouped GEMM kernels.
///
/// This operator is used by grouped GEMM kernels where pipeline parameters
/// (has_hot_loop, num_loop, tail_number) are calculated on the device side
/// at runtime, not on the host side during compilation. This is necessary
/// because different GEMM problems in the group may have different K dimensions,
/// requiring different pipeline configurations that cannot be determined at
/// compile time.
///
/// @param a_dram_block_window_tmp Block window for A tensor in DRAM
/// @param b_dram_block_window_tmp Block window for B tensor in DRAM
/// @param aq_dram_block_window_tmp Block window for AQ (quantization scale) tensor in DRAM
/// @param num_loop Number of main loop iterations (calculated on device)
/// @param has_hot_loop Whether the pipeline has a hot loop (calculated on device)
/// @param tail_number Type of tail handling required (calculated on device)
/// @param p_smem Pointer to shared memory
/// @return Accumulated result tile in registers
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AQDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
void* p_smem,
index_t m = 0) const
{
const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) {
constexpr bool hot_loop = has_hot_loop_.value;
constexpr auto tail_num = tail_number_.value;
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
aq_dram_block_window_tmp,
m, // dummy value, won't be used
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
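// Usage sketch (editorial, hedged; mirrors the QuantGroupedGemmKernel code
// shown earlier in this change): the kernel computes the loop parameters per
// group at runtime and forwards them to this operator:
//
//   const index_t num_loop     = TilePartitioner::GetLoopNum(K_split);
//   const bool    has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
//   const TailNumber tail_num  = GemmPipeline::GetBlockLoopTailNum(num_loop);
//   auto c_tile = GemmPipeline{}(a_window, b_window, aq_window,
//                                num_loop, has_hot_loop, tail_num, smem_ptr);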
};
} // namespace ck_tile

View File

@@ -42,14 +42,18 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
CK_TILE_DEVICE constexpr auto
GetBQDramLoadWindow(const BQDramBlockWindowTmp& bq_dram_block_window_tmp) const
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using YPerTile = number<NPerBlockBQ>;
using XPerTile = number<KPerBlockBQ>;
using YPerTile =
std::conditional_t<std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
number<NPerBlockBQ>,
number<KPerBlockBQ>>;
using XPerTile =
std::conditional_t<std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
number<KPerBlockBQ>,
number<NPerBlockBQ>>;
auto bq_copy_dram_window =
make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(YPerTile(), XPerTile()),
make_tuple(YPerTile{}, XPerTile{}),
bq_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBQDramTileDistribution<Problem>());
return bq_copy_dram_window;
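// Editorial note (hedged; extents assumed): for NPerBlockBQ = 2 and
// KPerBlockBQ = 4, the window lengths above become (YPerTile, XPerTile) =
// (2, 4) for a ColumnMajor BQ layout ([N, K] order) and (4, 2) for a
// RowMajor one ([K, N] order), matching the layout-dependent global views
// built in QuantGemmKernel.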

View File

@@ -25,8 +25,16 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
// Support both RowMajor and ColumnMajor layouts for BQ
if constexpr(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetABQGlobalVectorLoadSize<Problem, BQDataType, KPerBlockBQ, NPerBlockBQ>();
}
else
{
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
}
}
template <typename Problem>
@@ -52,7 +60,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
WarpTile::at(I2),
Problem::TransposeC>;
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
if constexpr(PreshuffleQuant)
{
using TileEncodingPattern = tile_distribution_encoding_pattern_bq<
@@ -62,18 +69,21 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
NPerBlock / WarpGemm::kN,
ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()),
VecLoadSize,
BQLayout,
PreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
else
{
// KPerTile and NPerTile are LOGICAL dimensions (K quant groups and N quant groups)
using TileEncodingPattern =
tile_distribution_encoding_pattern_bq<BlockGemmShape,
WarpGemm,
BlockSize,
NPerBlockBQ,
KPerBlockBQ,
Problem::QuantGroupSize::kN>;
KPerBlockBQ, // Logical K dimension
NPerBlockBQ, // Logical N dimension
Problem::QuantGroupSize::kN,
BQLayout>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}

View File

@@ -33,6 +33,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
using I0 = number<0>;
using I1 = number<1>;
@@ -83,6 +87,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
using Base::PrefetchStages;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
@@ -125,7 +132,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t BQ_Buffer_Load_Inst_Num =
NPerBlock * KPerBlockBQ / (BlockSize * GetVectorSizeBQ());
NPerBlockBQ * KPerBlockBQ / (BlockSize * GetVectorSizeBQ());
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
@@ -167,6 +174,16 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
{
using Base = PipelineImplBase;
template <typename BDramWindow, typename BBlockTile_>
CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile,
const BDramWindow& b_dram_window)
{
using DestDataType = typename BBlockTile_::DataType;
using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType;
constexpr index_t UnaryOpSize = 8;
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(b_block_tile, b_dram_window);
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
@@ -194,11 +211,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_bq_col_major =
std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
constexpr bool is_bq_row_major =
std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
@@ -212,12 +227,22 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
static_assert(
PreshuffleQuant ||
(is_bq_row_major
? (KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])),
"Bq block window has incorrect lengths for defined BqLayout!");
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
using BQDramTileWindowStep = typename BQDramBlockWindowTmp::BottomTensorIndex;
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
// Note: BDataType PkInt4 gets converted during loading, before going to LDS
auto&& [a_lds_block, b_lds_block] =
Base::template GetABLdsTensorViews<ADataType, OverrideBDataType>(p_smem);
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
@@ -237,7 +262,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
decltype(make_static_distributed_tensor<ADataType>(BBlockTileDistr{}));
using BQBlockTile =
decltype(make_static_distributed_tensor<BQDataType>(BQBlockTileDistr{}));
@@ -258,18 +283,20 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
0)
: is_bq_col_major ? make_array(0, KPerBlockBQ)
: make_array(KPerBlockBQ, 0);
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
: make_array(0, KPerBlockBQ);
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// B tile gets converted to A datatype during loading
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(
bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -281,9 +308,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
// B datatype is converted to A datatype during loading
auto b_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
@@ -294,11 +322,13 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
__builtin_amdgcn_sched_barrier(0);
@@ -311,7 +341,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
{
block_sync_lds();
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -322,9 +352,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
// Note: BDataType PkInt4 gets converted during loading earlier
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
@@ -335,7 +366,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2],
bq_copy_dram_window,
bq_dram_tile_window_step);
@@ -347,7 +379,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
__builtin_amdgcn_sched_barrier(0);
i += 1;
@@ -383,7 +416,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
// Note: BDataType gets converted during loading from PkInt4
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
@@ -393,7 +427,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
block_gemm(
c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
}
@@ -415,7 +450,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
// Note: BDataType PkInt4 gets converted during loading
[](const OverrideBDataType& b) { return b; },
bq_dram_block_window_tmp,
n,
num_loop,
@@ -458,7 +494,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
// Note: BDataType PkInt4 gets converted during loading
[](const OverrideBDataType& b) { return b; },
bq_dram_block_window_tmp,
n, // dummy value, won't be used
num_loop,

View File

@@ -189,9 +189,10 @@ struct tile_distribution_encoding_pattern_aq_transposed_c
template <typename BlockGemmShape,
typename WarpGemm,
index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t YPerQ,
index_t KPerTile,
index_t NPerTile,
index_t NPerQ,
typename BQLayout = tensor_layout::gemm::ColumnMajor,
bool PreshuffleQuant = false>
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
{
@@ -210,36 +211,41 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
///
/// This function determines the optimal thread distribution pattern for loading and applying
/// quantization scales to the B matrix based on the quantization group size (XPerQ) relative
/// quantization scales to the B matrix based on the quantization group size (NPerQ) relative
/// to warp dimensions.
///
/// Three distinct distribution patterns are handled:
///
/// 1. Fine-grained quantization (XPerQ < WarpGemm::kN):
/// 1. Fine-grained quantization (NPerQ < WarpGemm::kN):
/// - Multiple quantization groups exist within a single warp's N-dimension
/// - Each warp processes multiple scales (WarpGemm::kN / XPerQ scales per warp)
/// - Distribution includes explicit replication factor (XR = XPerQ) for scale broadcast
/// - Example: XPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
/// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp)
/// - Distribution includes explicit replication factor (XR = NPerQ) for scale broadcast
/// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
///
/// 2. Medium-grained quantization (WarpGemm::kN <= XPerQ <= WarpGemm::kN * NWarps):
/// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps):
/// - Each warp handles exactly one quantization scale
/// - Scales are distributed across warps with replication factor XR = XPerQ / WarpGemm::kN
/// - Example: XPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
/// - Scales are distributed across warps with replication factor XR = NPerQ / WarpGemm::kN
/// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
///
/// 3. Coarse-grained quantization (XPerQ > WarpGemm::kN * NWarps):
/// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps):
/// - Quantization group spans multiple warps
/// - All warps share the same scale value
/// - Example: XPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
/// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
///
/// @return A static tile distribution encoding for the BQ scale tensor
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
// Preshuffle only supported for ColumnMajor currently
static_assert(!(PreshuffleQuant && std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>),
"PreshuffleQuant only supported for ColumnMajor BQLayout");
if constexpr(PreshuffleQuant)
{
// ColumnMajor only for preshuffle
constexpr index_t X1 = warp_size;
constexpr index_t X0 = XPerTile / warp_size;
constexpr index_t X0 = NPerTile / warp_size;
constexpr index_t Y1 = NWarps;
constexpr index_t Y0 = YPerTile / Y1;
constexpr index_t Y0 = KPerTile / Y1;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps>,
@@ -251,52 +257,97 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
}
else
{
if constexpr(YPerQ < WarpGemm::kN)
if constexpr(NPerQ < WarpGemm::kN)
{
// Case 1: Fine-grained - multiple quantization scales within a single warp
constexpr index_t X = XPerTile; // Full X dimension of tile
constexpr index_t XR = 1; // No Y replication needed
constexpr index_t Y0 = NIterPerWarp; // Iterations per warp in N-dim
constexpr index_t Y1 = NWarps; // Number of warps in N-dim
constexpr index_t Y2 = WarpGemm::kN / YPerQ; // Number of scales per warp
constexpr index_t YR = YPerQ; // Elements per quantization group
// N dimension needs to be partitioned the same way regardless of layout
constexpr index_t NR = 1; // No N replication needed
constexpr index_t N0 = NIterPerWarp; // Iterations per warp in N-dim
constexpr index_t N1 = NWarps; // Number of warps in N-dim
constexpr index_t N2 = WarpGemm::kN / NPerQ; // Number of scales per warp
static_assert(Y0 * Y1 * Y2 == YPerTile,
"Y0, Y1, Y2 must cover the blocktile along Y.");
static_assert(N0 * N1 * N2 == NPerTile,
"N0, N1, N2 must cover the blocktile along N dimension.");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, XR, YR>,
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
tuple<sequence<0, 1>, sequence<0, 1, 0>>,
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
sequence<1, 2>,
sequence<0, 0>>{});
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>)
{
// ColumnMajor: [(N0, N1, N2), K] - N on Y-axis, partition Y
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NR, NPerQ>,
tuple<sequence<N0, N1, N2>, sequence<KPerTile>>,
tuple<sequence<0, 1>, sequence<0, 1, 0>>,
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
else
{
// RowMajor: [K, (N0, N1, N2)] - N on X-axis, partition X
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NR, NPerQ>,
tuple<sequence<KPerTile>, sequence<N0, N1, N2>>,
tuple<sequence<0, 2>, sequence<0, 2, 0>>,
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
sequence<2, 1>,
sequence<0, 0>>{});
}
}
else if constexpr(YPerQ <= WarpGemm::kN * NWarps)
else if constexpr(NPerQ <= WarpGemm::kN * NWarps)
{
// Case 2: Medium-grained - one quantization scale per warp
constexpr auto YR = YPerQ / WarpGemm::kN; // Scale replication factor
constexpr auto Y1 = NWarps / YR; // Warps per unique scale
constexpr auto Y0 = YPerTile / Y1; // Iterations to cover X dimension
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, YR, get_warp_size()>,
tuple<sequence<Y0, Y1>, sequence<XPerTile>>,
tuple<sequence<0, 1, 0>, sequence<0>>,
tuple<sequence<0, 1, 1>, sequence<2>>,
sequence<1, 2>,
sequence<0, 0>>{});
constexpr auto NR = NPerQ / WarpGemm::kN; // Scale replication factor
constexpr auto N1 = NWarps / NR; // Warps per unique scale
constexpr auto N0 = NPerTile / N1; // Iterations to cover N dimension
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>)
{
// ColumnMajor: [(N0, N1), K] - N on Y-axis
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NR, get_warp_size()>,
tuple<sequence<N0, N1>, sequence<KPerTile>>,
tuple<sequence<0, 1, 0>, sequence<0>>,
tuple<sequence<0, 1, 1>, sequence<2>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
else
{
// RowMajor: [K, (N0, N1)] - N on X-axis
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NR, get_warp_size()>,
tuple<sequence<KPerTile>, sequence<N0, N1>>,
tuple<sequence<0, 2, 0>, sequence<0>>,
tuple<sequence<0, 1, 1>, sequence<2>>,
sequence<2, 1>,
sequence<0, 0>>{});
}
}
else // XPerQ > WarpGemm::kN * NWarps
else // NPerQ > WarpGemm::kN * NWarps
{
// Case 3: Coarse-grained - quantization group spans all warps
// All warps in N-dimension share the same quantization scale
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, get_warp_size()>,
tuple<sequence<YPerTile>, sequence<XPerTile>>,
tuple<sequence<0, 0>, sequence<0>>,
tuple<sequence<0, 1>, sequence<2>>,
sequence<2, 1>,
sequence<0, 0>>{});
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>)
{
// ColumnMajor: [N, K]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, get_warp_size()>,
tuple<sequence<NPerTile>, sequence<KPerTile>>,
tuple<sequence<0, 0>, sequence<0>>,
tuple<sequence<0, 1>, sequence<2>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
else
{
// RowMajor: [K, N]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, get_warp_size()>,
tuple<sequence<KPerTile>, sequence<NPerTile>>,
tuple<sequence<0, 0>, sequence<0>>,
tuple<sequence<0, 1>, sequence<2>>,
sequence<2, 1>,
sequence<0, 0>>{});
}
}
}
}
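The branches above split the B-scale distribution by how far one quantization group reaches across N. A minimal constexpr sketch of the case split, assuming the Case 1 guard (above this excerpt) is NPerQ <= WarpGemm::kN:

// Hypothetical sketch, not part of the kernel: which distribution case fires
// for a given NPerQ (N extent covered by one scale), warp tile width kN, and
// warp count NWarps along N.
constexpr int quant_scale_case(int NPerQ, int kN, int NWarps)
{
    if(NPerQ <= kN)
        return 1; // fine-grained: several distinct scales inside one warp tile
    if(NPerQ <= kN * NWarps)
        return 2; // medium: one scale per warp, replicated NR = NPerQ / kN times
    return 3;     // coarse: all N-warps read the same scale
}
static_assert(quant_scale_case(8, 16, 4) == 1);
static_assert(quant_scale_case(32, 16, 4) == 2);
static_assert(quant_scale_case(128, 16, 4) == 3);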

View File

@@ -86,8 +86,8 @@ class TestCkTileGemmQuantBase : public ::testing::Test
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
// Re-use the AQLayout for BQLayout
using BQLayout = AQLayout;
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,

View File

@@ -28,42 +28,58 @@ using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
// Type combinations for BQuant tests (without PreshuffleB)
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BQuantTypes = ::testing::Types<
// 1d cases with grouping only on k axis
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
// 2d cases with grouping also on the n axis
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D128N>,
// some cases with transpose layouts
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
// pkint4 + transpose cases
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<ColumnMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>
>;
// clang-format on
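Read the QuantGroupShape aliases above as one scale per (kM x kN x kK) block; a small sketch of the resulting scale-tensor shape, with hypothetical problem sizes:

// Assumed semantics of QuantGroupShape<sequence<1, 64, 128>> (GroupSize2D64N):
// one B scale per 64-wide N block and 128-deep K block.
constexpr int N = 256, K = 512;
constexpr int BQN = N / 64;  // 4 scale columns along N
constexpr int BQK = K / 128; // 4 scale rows along K
static_assert(BQN == 4 && BQK == 4);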

View File

@@ -26,60 +26,60 @@ using GroupSize2D32N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
using GroupSize2D64N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
// Type combinations for BQuant tests with PreshuffleB
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BPreshuffleBQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
// 2d cases with preshuffle B
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBDecode, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D8N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D32N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize2D64N>
>;
// clang-format on

View File

@@ -389,6 +389,9 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
using typename Base::QDataType;
using typename Base::QuantGroupSize;
// Re-use AQLayout from tuple parameters as BQLayout
using BQLayout = typename Base::AQLayout;
static constexpr auto QuantType = Base::QuantType;
static constexpr auto PreshuffleB = Base::PreshuffleB;
static constexpr auto TiledMMAPermuteN = Base::TiledMMAPermuteN;
@@ -406,16 +409,15 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
// BQuant uses block/grouped quantization for B matrix
const ck_tile::index_t BQN = ck_tile::integer_divide_ceil(N, QuantGroupSize::kN);
const ck_tile::index_t BQK = ck_tile::integer_divide_ceil(K, QuantGroupSize::kK);
const ck_tile::index_t stride_BQ = this->is_row_major(BQLayout{}) ? BQN : BQK;
// Generate test data
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
ck_tile::HostTensor<QDataType> bq_bqk_bqn(
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{})));
// Initialize data with random values
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
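A worked example of the layout-dependent BQ stride above (hypothetical sizes, default 1-D group shape):

// Hypothetical sizes: N = 256, K = 512, QuantGroupSize = <1, 1, 128>.
constexpr int BQN = (256 + 1 - 1) / 1;     // 256: one scale column per N (kN == 1)
constexpr int BQK = (512 + 128 - 1) / 128; // 4:   one scale row per 128-deep K group
constexpr int stride_row_major = BQN; // BQ stored [BQK, BQN] row-major
constexpr int stride_col_major = BQK; // BQ stored column-major
static_assert(stride_row_major == 256 && stride_col_major == 4);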

View File

@@ -14,6 +14,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()

View File

@@ -18,32 +18,41 @@ using True = ck_tile::bool_constant<true>;
using False = ck_tile::bool_constant<false>;
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
using AQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>,
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>,
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>,
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>,
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>,
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>,
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>
>;
// clang-format on

View File

@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_util_quant.hpp"
using F16 = ck_tile::half_t;
using F32 = float;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using True = ck_tile::bool_constant<true>;
using False = ck_tile::bool_constant<false>;
using AQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
// clang-format off
using KernelTypes_AQuant = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, True>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, False>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_AQuant, KernelTypes_AQuant);
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_AQuant
#include "test_grouped_gemm_quant_ut_cases.inc"
#undef TEST_CLASS_NAME
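The define/include/undef trio stamps the shared grouped-GEMM quant cases out under this fixture. The .inc body is not part of this diff; a hypothetical illustration of the pattern (SmokeTest and Run are stand-in names):

// Hypothetical *_ut_cases.inc content, for illustration only:
//     TYPED_TEST(TEST_CLASS_NAME, SmokeTest) { this->Run(); }
// With the define above, the preprocessor expands this to
//     TYPED_TEST(TestCkTileGroupedGemmQuant_AQuant, SmokeTest) { this->Run(); }
// and gtest instantiates it once per tuple in KernelTypes_AQuant.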

View File

@@ -20,9 +20,14 @@ using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQ
// clang-format off
using KernelTypes_BQuant = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, False, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, False, False>
>;
// clang-format on

View File

@@ -20,11 +20,14 @@ using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantTyp
// clang-format off
using KernelTypes_RowCol = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, False, False>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, False, False>
>;
// clang-format on

View File

@@ -20,11 +20,14 @@ using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantTyp
// clang-format off
using KernelTypes_Tensor = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>
>;
// clang-format on

View File

@@ -3,6 +3,7 @@
#pragma once
#include <sstream>
#include <gtest/gtest.h>
#include <type_traits>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
@@ -32,24 +33,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using AQLayout = Row;
using BQLayout = Col;
static constexpr bool PreshuffleB = std::tuple_element_t<10, Tuple>::value;
static constexpr bool Persistent = std::tuple_element_t<11, Tuple>::value;
static constexpr bool TransposeC = std::tuple_element_t<12, Tuple>::value;
struct GroupedGemKernelParam_Mfma
{
@@ -66,11 +52,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
static const ck_tile::index_t N_Warp = 2;
static const ck_tile::index_t K_Warp = 1;
static const ck_tile::index_t M_Warp_Tile = 16;
static const ck_tile::index_t N_Warp_Tile = 16;
static const ck_tile::index_t K_Warp_Tile = 32;
};
struct GroupedGemKernelParam_Wmma : public GroupedGemKernelParam_Mfma
@@ -90,16 +74,201 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg);
}
template <typename GroupedGemKernelParam, typename ALayout, typename BLayout, typename CLayout>
float invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
{
constexpr bool DoubleSmemBuffer =
PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped ||
QuantType == ck_tile::QuantType::BQuantGrouped;
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
GroupedGemKernelParam::N_Tile,
GroupedGemKernelParam::K_Tile>,
ck_tile::sequence<GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::K_Warp>,
ck_tile::sequence<GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
ALayout,
BLayout,
CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
false,
PreshuffleB,
ALayout,
BLayout,
CLayout,
QuantType,
AQLayout,
BQLayout,
TransposeC,
DoubleSmemBuffer,
Persistent>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = std::conditional_t<
UseGroupedQuant,
std::conditional_t<
QuantType == ck_tile::QuantType::AQuantGrouped,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
std::conditional_t<
PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GroupedGemKernelParam::K_Tile;
const ck_tile::index_t K_split =
(gemm_descs[0].K + k_grain - 1) / k_grain * GroupedGemKernelParam::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
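// k_grain:  K elements each k-batch split must cover (k_batch * K_Tile).
// K_split:  per-split K extent, rounded up to whole K_Tile iterations.
// num_loop: main-loop trip count derived from K_split; has_hot_loop and
//           tail_num select the matching compile-time pipeline variant below.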
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
std::conditional_t<QuantType == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
TransposeC,
BDataType,
scheduler,
has_hot_loop_v,
tail_number_v>,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
ADataType,
scheduler,
has_hot_loop_v,
tail_number_v>>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
TransposeC,
BDataType,
scheduler,
has_hot_loop_v,
tail_number_v>>;
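// Pipeline selection mirrors the problem selection above:
//   AQuantGrouped                 -> AQuantGemmPipelineAgBgCrCompV3
//   BQuantGrouped + PreshuffleB   -> WPQuantBPipelineAgBgCrV2
//   BQuantGrouped (no preshuffle) -> BQuantGemmPipelineAgBgCrCompV3
//   RowCol / Tensor quant         -> GemmPipelineAgBgCrCompV3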
using GemmPipeline = std::conditional_t<
UseGroupedQuant,
std::conditional_t<
QuantType == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
QuantGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
GemmUniversalTraits::kQuantType>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName()
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
<< blocks.z << "}" << std::endl;
}
return ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GroupedGemKernelParam::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
};
return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
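TailHandler (called just above) lifts the runtime has_hot_loop and tail_num values into the compile-time constants the Run lambda consumes. A simplified, self-contained sketch of that dispatch pattern (Tail, bc, tc and dispatch_tail are stand-in names, trimmed to two tail cases):

// Sketch of the TailHandler dispatch idea; not the real ck_tile implementation.
#include <type_traits>
enum class Tail { Full, Odd }; // stand-in for ck_tile::TailNumber
template <bool B> using bc = std::integral_constant<bool, B>;
template <Tail T> using tc = std::integral_constant<Tail, T>;

template <typename F>
float dispatch_tail(F&& run, bool has_hot_loop, Tail tail)
{
    if(has_hot_loop)
        return tail == Tail::Full ? run(bc<true>{}, tc<Tail::Full>{})
                                  : run(bc<true>{}, tc<Tail::Odd>{});
    return run(bc<false>{}, tc<Tail::Full>{}); // no hot loop: single tail variant
}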
template <typename GroupedGemKernelParam, typename ALayout, typename BLayout, typename CLayout>
void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr)
{
constexpr bool DoubleSmemBuffer =
PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
@@ -131,40 +300,53 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
BQLayout,
TransposeC,
DoubleSmemBuffer,
Persistent>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
constexpr auto memory_operation = memory_operation_.value;
// We create the GEMM pipeline without specifying the hot loop or tail number;
// these are resolved automatically inside the kernel from the given input data.
constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped ||
QuantType == ck_tile::QuantType::BQuantGrouped;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
std::conditional_t<QuantType == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
TransposeC>,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize>>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
TransposeC,
BDataType,
scheduler>>;
using GemmPipeline = std::conditional_t<
UseGroupedQuant,
std::conditional_t<
QuantType == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
@@ -199,7 +381,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
}
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GroupedGemKernelParam::kBlockPerCu>(
Kernel{},
grids,
blocks,
@@ -292,13 +474,24 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
}
else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped)
{
AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize
BQK = 0; // No B quantization
if(K % QuantGroupSize::kK != 0)
{
throw std::runtime_error(
"K must be divisible by QuantGroupSize::kK for AQuantGrouped mode");
}
}
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
{
AQK = 0; // No A quantization
BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize
if(K % QuantGroupSize::kK != 0)
{
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
throw std::runtime_error(
"K must be divisible by QuantGroupSize::kK for BQuantGrouped mode");
}
}
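For instance, with this test's <1, 1, 128> group shape, K = 512 divides cleanly into BQK = 4 scale rows, while K = 500 leaves a remainder and throws; a one-line check of the arithmetic:

static_assert(512 % 128 == 0 && 512 / 128 == 4); // K = 512: BQK = 4, accepted
static_assert(500 % 128 == 116);                 // K = 500: remainder, throws at runtime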
@@ -317,6 +510,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
}
else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped)
{
stride_AQs[i] =
ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(AQLayout()));
stride_BQs[i] = 0; // No B quantization
}
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
{
stride_AQs[i] = 0; // No A quantization
@@ -348,11 +547,20 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
1, 1, stride_BQs[i], is_row_major(BQLayout()))));
}
else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped)
{
aq_tensors.push_back(
ck_tile::HostTensor<AQDataType>(ck_tile::host_tensor_descriptor(
M, AQK, stride_AQs[i], is_row_major(AQLayout{}))));
bq_tensors.push_back(
ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
0, 0, stride_BQs[i], is_row_major(BQLayout()))));
}
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
{
aq_tensors.push_back(
ck_tile::HostTensor<AQDataType>(ck_tile::host_tensor_descriptor(
0, 0, stride_AQs[i], is_row_major(AQLayout{}))));
bq_tensors.push_back(
ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
BQK, N, stride_BQs[i], is_row_major(BQLayout()))));
@@ -429,11 +637,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
if constexpr(Persistent)
{
// Generate kernel arguments
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
assert(gemm_descs[0].k_batch == 1);
for(const auto& arg : gemm_descs)
{
@@ -471,7 +680,14 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
}
else
{
const auto stream = ck_tile::stream_config{nullptr, false, 1};
#if CK_TILE_USE_WMMA
invoke_grouped_gemm<GroupedGemKernelParam_Wmma, ALayout, BLayout, CLayout>(
gemm_descs, stream, kargs_ptr);
#else
invoke_grouped_gemm<GroupedGemKernelParam_Mfma, ALayout, BLayout, CLayout>(
gemm_descs, stream, kargs_ptr);
#endif
}
// Copy results back to host for validation
@@ -512,7 +728,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
bq_tensors[i],
c_m_n_host_ref);
}
else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped)
{
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
@@ -520,6 +736,17 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
AccDataType,
CDataType,
QuantGroupSize,
true>(
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
}
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
{
ck_tile::reference_gemm_quant<ADataType,
BQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
}
@@ -550,5 +777,8 @@ using TestCkTileGroupedGemmQuant_RowCol = TestCkTileGroupedGemmQuant<Tuple>;
template <typename Tuple>
using TestCkTileGroupedGemmQuant_Tensor = TestCkTileGroupedGemmQuant<Tuple>;
template <typename Tuple>
using TestCkTileGroupedGemmQuant_AQuant = TestCkTileGroupedGemmQuant<Tuple>;
template <typename Tuple>
using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant<Tuple>;