[CK Tile] Grouped GEMM aquant mode and non-persistent kernel (#3337)

* wip: add aquant to grouped gemm quant example

* fix: properly handle hot loop count in aquant pipeline

* fix: add separate GemmConfig structs for AQuant, automatically select the correct one

* feat: finish support for a non-persistent kernel invocation for grouped gemm quant, and add support code to example

* refactor: cleaned up grouped gemm quant example a bit by reusing pipeline selection logic

* chore: add warp gemm dispatchers for a couple of TransposeC K=32 variants

* feat: add quant grouped gemm tests cases for aquant (regular and transpose C) and non-persistent kernel

* fix: update base pipeline classes according to changes in develop branch

* Revert "chore: add warp gemm dispatchers for a couple of TransposeC K=32 variants"

This reverts commit b3fd4d326d.

* feat: remove aquant config from grouped gemm quant example, update to add persistency as runtime parameter

* chore: removed work-around for aquant bug that has been fixed

* chore: fix typo in command-line parameters

* fix: correct K warp tile size for gfx950

* chore: fix incorrect warp tile configuration on gfx942
This commit is contained in:
Erwin Terpstra
2025-12-08 21:19:22 +01:00
committed by GitHub
parent ca6143f0b2
commit fe07b5a1bf
12 changed files with 948 additions and 206 deletions

View File

@@ -9,14 +9,190 @@
#include <string>
#include <tuple>
#include <memory>
#include <type_traits>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm_quant.hpp"
#include "ck_tile/host.hpp"
#include "quant_grouped_gemm.hpp"
// Dispatches and launches the NON-persistent grouped quantized GEMM kernel:
// the grid is sized for all tiles of all groups up front, as opposed to
// grouped_gemm_tileloop(), which uses a persistent tile-loop kernel.
//
// Template parameters select layouts, element types, the quantization group
// size, and the quantization mode (defaults to grouped B-quantization).
//
// gemm_descs : host-side problem descriptors, one per group
// s          : stream / benchmarking configuration (warmup, repeats, logging)
// kargs_ptr  : device buffer that receives the per-group kernel arguments
// Returns the average kernel time measured by ck_tile::launch_kernel.
// Throws std::runtime_error if the generated kernel arguments are unsupported.
template <typename GemmConfig,
          typename ALayout,
          typename AQLayout,
          typename BLayout,
          typename BQLayout,
          typename CLayout,
          typename ADataType,
          typename AQDataType,
          typename BDataType,
          typename BQDataType,
          typename AccDataType,
          typename CDataType,
          typename QuantGroupSize,
          ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
                   const ck_tile::stream_config& s,
                   void* kargs_ptr)
{
    // Spatially-local tile partitioning parameters (see
    // GemmSpatiallyLocalTilePartitioner). Note: "Paritioner" spelling kept for
    // consistency with the rest of the file.
    constexpr ck_tile::index_t TileParitionerGroupNum = 8;
    constexpr ck_tile::index_t TileParitionerM01      = 4;

    // Block tile / warp layout / per-warp tile sizes, all taken from GemmConfig.
    using GemmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
        ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
        ck_tile::
            sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
    using TilePartitioner = ck_tile::
        GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;

    // Plain (non-quant) traits, used only for the base-pipeline problem below.
    using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
                                           GemmConfig::kPadN,
                                           GemmConfig::kPadK,
                                           ALayout,
                                           BLayout,
                                           CLayout>;
    // Quantization-aware traits for the actual pipeline/kernel instantiation.
    using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
                                                             GemmConfig::kPadN,
                                                             GemmConfig::kPadK,
                                                             false, // PreshuffleQuant
                                                             GemmConfig::PreshuffleB,
                                                             ALayout,
                                                             BLayout,
                                                             CLayout,
                                                             QuantMode,
                                                             AQLayout,
                                                             BQLayout,
                                                             GemmConfig::TransposeC,
                                                             GemmConfig::DoubleSmemBuffer,
                                                             GemmConfig::Persistent>;

    using GemmPipelineProblem =
        ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
    // NOTE(review): dependent alias without `typename` — relies on C++20
    // implicit typename (P0634); would not compile as C++17. Same at the
    // GemmPipeline alias inside Run below.
    using BaseGemmPipeline =
        GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
                                                              GemmConfig::PreshuffleB>;

    // Hot-loop / tail-number classification for the pipeline. This is derived
    // from the FIRST group only — presumably all groups share a compatible
    // K / k_batch so the same tail handling applies; TODO confirm.
    const ck_tile::index_t k_grain  = gemm_descs[0].k_batch * GemmConfig::K_Tile;
    const ck_tile::index_t K_split  = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
    const bool has_hot_loop         = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

    float ave_time{0};

    // Run is invoked by TailHandler with compile-time hot-loop / tail-number
    // values so the pipeline can be fully specialized.
    const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
        constexpr bool has_hot_loop_v = has_hot_loop_.value;
        constexpr auto tail_number_v  = tail_number_.value;
        constexpr auto scheduler      = GemmConfig::Scheduler;
        constexpr auto memory_operation = ck_tile::memory_operation_enum::set;

        // Grouped quant (A or B) uses the dedicated grouped pipeline problems;
        // RowCol/Tensor quant share the row-col/tensor problem type.
        constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
                                         QuantMode == ck_tile::QuantType::BQuantGrouped;
        using QuantGemmProblem = std::conditional_t<
            UseGroupedQuant,
            std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
                               ck_tile::GemmAQuantPipelineProblem<ADataType,
                                                                  AQDataType,
                                                                  BDataType,
                                                                  AccDataType,
                                                                  GemmShape,
                                                                  GemmUniversalTraits,
                                                                  QuantGroupSize,
                                                                  GemmConfig::TransposeC,
                                                                  BDataType,
                                                                  scheduler,
                                                                  has_hot_loop_v,
                                                                  tail_number_v>,
                               ck_tile::GemmBQuantPipelineProblem<ADataType,
                                                                  BDataType,
                                                                  BQDataType,
                                                                  AccDataType,
                                                                  GemmShape,
                                                                  GemmUniversalTraits,
                                                                  QuantGroupSize,
                                                                  ADataType,
                                                                  scheduler,
                                                                  has_hot_loop_v,
                                                                  tail_number_v>>,
            ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
                                                          BDataType,
                                                          AccDataType,
                                                          AccDataType,
                                                          GemmShape,
                                                          GemmUniversalTraits,
                                                          GemmConfig::TransposeC,
                                                          BDataType,
                                                          scheduler,
                                                          has_hot_loop_v,
                                                          tail_number_v>>;

        // Pipeline selection is centralized in GemmQuantConfig<QuantMode>.
        using GemmPipeline =
            GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
                                                              GemmConfig::PreshuffleB>;

        using GemmEpilogue = ck_tile::CShuffleEpilogue<
            ck_tile::CShuffleEpilogueProblem<ADataType,
                                             BDataType,
                                             ck_tile::tuple<>,
                                             AccDataType,
                                             CDataType,
                                             ck_tile::tuple<>,
                                             CLayout,
                                             ck_tile::element_wise::PassThrough,
                                             TilePartitioner::MPerBlock,
                                             TilePartitioner::NPerBlock,
                                             GemmConfig::M_Warp,
                                             GemmConfig::N_Warp,
                                             GemmConfig::M_Warp_Tile,
                                             GemmConfig::N_Warp_Tile,
                                             GemmConfig::K_Warp_Tile,
                                             QuantGemmProblem::TransposeC,
                                             memory_operation>>;
        using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
                                                       GemmPipeline,
                                                       GemmEpilogue,
                                                       GemmUniversalTraits::kQuantType>;

        // Build per-group kernel arguments on the host and validate them.
        auto kargs = Kernel::MakeKargs(gemm_descs);
        if(!Kernel::IsSupportedArgument(kargs))
        {
            throw std::runtime_error("Kernel arguments not supported!");
        }

        const dim3 blocks = Kernel::BlockSize();
        const dim3 grids  = Kernel::GridSize(gemm_descs);

        // Copy the per-group argument array to the caller-provided device
        // buffer on the benchmark stream before launching.
        HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
                                            kargs.data(),
                                            get_workspace_size(gemm_descs),
                                            hipMemcpyHostToDevice,
                                            s.stream_id_));

        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
                      << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
                      << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
        }

        // The kernel reads its arguments from constant address space.
        return ave_time = ck_tile::launch_kernel(
                   s,
                   ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
                       Kernel{},
                       grids,
                       blocks,
                       0,
                       ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
                       gemm_descs.size()));
    };

    // TailHandler lifts the runtime hot-loop/tail values to compile time and
    // calls Run with the matching specialization.
    return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
template <typename GemmConfig,
typename ALayout,
typename AQLayout,
@@ -59,41 +235,48 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
BQLayout,
GemmConfig::TransposeC,
GemmConfig::DoubleSmemBuffer,
true>; // Persistence
GemmConfig::Persistent>;
float ave_time{0};
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
constexpr bool transpose_c = false;
using QuantGemmProblem = typename std::conditional<
QuantMode == ck_tile::QuantType::BQuantGrouped,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize>,
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::BQuantGrouped;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
GemmConfig::TransposeC>,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize>>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
transpose_c,
GemmConfig::TransposeC,
BDataType,
scheduler>>::type;
scheduler>>;
using GemmPipeline = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<GemmConfig::PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>;
using GemmPipeline =
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
@@ -146,6 +329,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
// Example entry point. Quant mode, precision, layouts and kernel persistency
// are all selected at runtime from command-line arguments inside
// run_grouped_gemm_example(), so no template configuration is needed here.
// (The diff residue kept both the old negated templated call and the new call,
// redeclaring result1; only the updated call is retained.)
int main(int argc, char* argv[])
{
    int result1 = run_grouped_gemm_example(argc, argv);
    return result1;
}

View File

@@ -64,6 +64,7 @@ struct GemmTypeConfig<ck_tile::bf8_t>
using CDataType = ck_tile::half_t;
};
template <bool Persistent_>
struct GemmConfigBase
{
static constexpr bool kPadM = false;
@@ -83,10 +84,11 @@ struct GemmConfigBase
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool PreshuffleB = false;
static constexpr bool Persistent = Persistent_;
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
template <typename PrecType, bool Persistent>
struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent>
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -101,8 +103,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType>
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
template <typename PrecType, bool Persistent>
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase<Persistent>
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -121,6 +123,66 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
static constexpr bool DoubleSmemBuffer = true;
};
// Compile-time policy mapping a quantization mode to:
//   * GemmConfig       — the tile/warp configuration struct to use,
//   * GemmPipeline     — the concrete GEMM pipeline for a given problem,
//   * BaseGemmPipeline — the matching base pipeline (hot-loop/tail logic).
// The primary template is intentionally left undefined so that an unsupported
// QuantType fails to compile.
template <ck_tile::QuantType QuantMode>
struct GemmQuantConfig;

// Tensor-wide quantization: plain compute-v3 pipeline.
template <>
struct GemmQuantConfig<ck_tile::QuantType::TensorQuant>
{
    template <typename PrecType, bool Persistent>
    using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;

    // PreshuffleB is accepted for interface uniformity but ignored here.
    template <typename GemmProblem, bool PreshuffleB = false>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmProblem>;

    template <typename GemmProblem, bool PreshuffleB = false>
    using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
};

// Row/column quantization: same pipeline selection as tensor quantization.
template <>
struct GemmQuantConfig<ck_tile::QuantType::RowColQuant>
{
    template <typename PrecType, bool Persistent>
    using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;

    // PreshuffleB is accepted for interface uniformity but ignored here.
    template <typename GemmProblem, bool PreshuffleB = false>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmProblem>;

    template <typename GemmProblem, bool PreshuffleB = false>
    using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
};

// Grouped A-quantization: dedicated AQuant pipeline over the compute-v3 base.
template <>
struct GemmQuantConfig<ck_tile::QuantType::AQuantGrouped>
{
    template <typename PrecType, bool Persistent>
    using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;

    // PreshuffleB is accepted for interface uniformity but ignored here.
    template <typename GemmProblem, bool PreshuffleB = false>
    using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<GemmProblem>;

    template <typename GemmProblem, bool PreshuffleB = false>
    using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
};

// Grouped B-quantization: the only mode where PreshuffleB actually switches
// the pipeline (weight-preshuffle V2 vs compute-v3 BQuant).
template <>
struct GemmQuantConfig<ck_tile::QuantType::BQuantGrouped>
{
    template <typename PrecType, bool Persistent>
    using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill<PrecType, Persistent>;

    template <typename GemmProblem, bool PreshuffleB = false>
    using GemmPipeline = std::conditional_t<PreshuffleB == true,
                                            ck_tile::WPQuantBPipelineAgBgCrV2<GemmProblem>,
                                            ck_tile::BQuantGemmPipelineAgBgCrCompV3<GemmProblem>>;

    template <typename GemmProblem, bool PreshuffleB = false>
    using BaseGemmPipeline =
        std::conditional_t<PreshuffleB == true,
                           ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmProblem>,
                           ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>>;
};
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
auto create_args(int argc, char* argv[])
@@ -148,8 +210,9 @@ auto create_args(int argc, char* argv[])
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "8", "group count.")
.insert("kbatch", "1", "kbatch for SplitK")
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol")
.insert("init", "0", "0. Random, 2. One(s) (Constant)");
.insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor, or rowcol")
.insert("init", "0", "0. Random, 2. One(s) (Constant)")
.insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -57,56 +57,83 @@ float invoke_gemm(int n_warmup,
float ave_time = 0;
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
// the gemm problems known on the host. Instead, we can just pass the pointer
// to the kernel and let the workgroups figure out which tiles to work on.
// This is useful when the gemm problems are generated dynamically.
// In this example however, we generate the `kargs` using the known gemm_descs,
// and copy the gemm descriptions to the device memory.
// The contents of the memory pointed to by `kargs_ptr` pointer could be
// written by e.g. another kernel from earlier stage.
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
assert(args[0].k_batch == 1);
for(const auto& arg : args)
if constexpr(!GemmConfig::Persistent)
{
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
arg.b_ptr,
arg.aq_ptr,
arg.bq_ptr,
arg.e_ptr,
arg.M,
arg.N,
arg.K,
arg.QK_A,
arg.QK_B,
arg.stride_A,
arg.stride_B,
arg.stride_E,
arg.stride_AQ,
arg.stride_BQ,
arg.k_batch});
ave_time =
grouped_gemm<GemmConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
QuantGroupSize,
QuantMode>(args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
gemm_workspace.GetDeviceBuffer());
}
else
{
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
// the gemm problems known on the host. Instead, we can just pass the pointer
// to the kernel and let the workgroups figure out which tiles to work on.
// This is useful when the gemm problems are generated dynamically.
// In this example however, we generate the `kargs` using the known gemm_descs,
// and copy the gemm descriptions to the device memory.
// The contents of the memory pointed to by `kargs_ptr` pointer could be
// written by e.g. another kernel from earlier stage.
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
if(args[0].k_batch != 1)
{
throw std::runtime_error("Split-K not supported yet for persistent kernel");
}
for(const auto& arg : args)
{
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
arg.b_ptr,
arg.aq_ptr,
arg.bq_ptr,
arg.e_ptr,
arg.M,
arg.N,
arg.K,
arg.QK_A,
arg.QK_B,
arg.stride_A,
arg.stride_B,
arg.stride_E,
arg.stride_AQ,
arg.stride_BQ,
arg.k_batch});
}
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time = grouped_gemm_tileloop<GemmConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
QuantGroupSize,
QuantMode>(stream, group_count, kargs_ptr);
}
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time = grouped_gemm_tileloop<GemmConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
QuantGroupSize,
QuantMode>(stream, group_count, kargs_ptr);
std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")";
@@ -259,13 +286,24 @@ int run_grouped_gemm_example_with_layouts(int argc,
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
}
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize
BQK = 0; // No B quantization
if(K % QuantGroupSize::kK != 0)
{
throw std::runtime_error(
"K must be divisible by QuantGroupSize::kK for AQuantGrouped mode");
}
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
AQK = 0; // No A quantization
BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize
if(K % QuantGroupSize::kK != 0)
{
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
throw std::runtime_error(
"K must be divisible by QuantGroupSize::kK for BQuantGrouped mode");
}
}
@@ -284,6 +322,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
}
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
stride_AQs[i] =
ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout));
stride_BQs[i] = 0; // No B quantization
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
stride_AQs[i] = 0; // No A quantization
@@ -311,10 +355,17 @@ int run_grouped_gemm_example_with_layouts(int argc,
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout))));
}
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout))));
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
ck_tile::host_tensor_descriptor(0, 0, stride_BQs[i], is_row_major(bq_layout))));
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout))));
ck_tile::host_tensor_descriptor(0, 0, stride_AQs[i], is_row_major(aq_layout))));
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout))));
}
@@ -444,7 +495,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
bq_tensors[i],
c_m_n_host_ref);
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
@@ -452,6 +503,17 @@ int run_grouped_gemm_example_with_layouts(int argc,
AccDataType,
CDataType,
QuantGroupSize,
true>(
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
ck_tile::reference_gemm_quant<ADataType,
BQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
}
@@ -477,7 +539,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
return pass;
}
template <typename GemmConfig, typename PrecType, ck_tile::QuantType QuantMode>
template <typename PrecType, ck_tile::QuantType QuantMode, typename GemmConfig>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
@@ -494,6 +556,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
@@ -511,7 +574,24 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
}
}
template <template <typename PrecType> typename GemmConfig>
template <typename PrecType, ck_tile::QuantType QuantMode>
int run_gemm_example_persistency(
std::string a_layout, std::string b_layout, bool persistent, int argc, char* argv[])
{
if(persistent)
{
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType, true>;
return run_gemm_example_prec_type<PrecType, QuantMode, GemmConfig>(
a_layout, b_layout, argc, argv);
}
else
{
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType, false>;
return run_gemm_example_prec_type<PrecType, QuantMode, GemmConfig>(
a_layout, b_layout, argc, argv);
}
}
int run_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -524,29 +604,29 @@ int run_grouped_gemm_example(int argc, char* argv[])
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
std::string quant_mode = arg_parser.get_str("quant_mode");
bool persistent = arg_parser.get_bool("persistent");
if(data_type == "fp8")
{
if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "aquant")
{
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, persistent, argc, argv);
}
else
{
@@ -557,24 +637,23 @@ int run_grouped_gemm_example(int argc, char* argv[])
{
if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "aquant")
{
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, persistent, argc, argv);
}
else
{