mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK Tile] Grouped GEMM aquant mode and non-persistent kernel (#3337)
* wip: add aquant to grouped gemm quant example
* fix: properly handle hot loop count in aquant pipeline
* fix: add separate GemmConfig structs for AQuant, automatically select the correct one
* feat: finish support for a non-persistent kernel invocation for grouped gemm quant, and add support code to example
* refactor: cleaned up grouped gemm quant example a bit by reusing pipeline selection logic
* chore: add warp gemm dispatchers for a couple of TransposeC K=32 variants
* feat: add quant grouped gemm tests cases for aquant (regular and transpose C) and non-persistent kernel
* fix: update base pipeline classes according to changes in develop branch
* Revert "chore: add warp gemm dispatchers for a couple of TransposeC K=32 variants"
This reverts commit b3fd4d326d.
* feat: remove aquant config from grouped gemm quant example, update to add persistency as runtime parameter
* chore: removed work-around for aquant bug that has been fixed
* chore: fix typo in command-line parameters
* fix: correct K warp tile size for gfx950
* chore: incorrect warp tile configuration on gfx942
This commit is contained in:
@@ -9,14 +9,190 @@
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "quant_grouped_gemm.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped>
|
||||
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
GemmConfig::Persistent>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
|
||||
|
||||
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped;
|
||||
using QuantGemmProblem = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
ADataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
|
||||
using GemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
@@ -59,41 +235,48 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
true>; // Persistence
|
||||
GemmConfig::Persistent>;
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
using QuantGemmProblem = typename std::conditional<
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize>,
|
||||
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped;
|
||||
|
||||
using QuantGemmProblem = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
GemmConfig::TransposeC>,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize>>,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
transpose_c,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler>>::type;
|
||||
scheduler>>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
std::conditional_t<GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>;
|
||||
using GemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
@@ -146,6 +329,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
int result1 = !run_grouped_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);
|
||||
int result1 = run_grouped_gemm_example(argc, argv);
|
||||
return result1;
|
||||
}
|
||||
|
||||
@@ -64,6 +64,7 @@ struct GemmTypeConfig<ck_tile::bf8_t>
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <bool Persistent_>
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
@@ -83,10 +84,11 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool PreshuffleB = false;
|
||||
static constexpr bool Persistent = Persistent_;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
template <typename PrecType, bool Persistent>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -101,8 +103,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
|
||||
template <typename PrecType, bool Persistent>
|
||||
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase<Persistent>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -121,6 +123,66 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <ck_tile::QuantType QuantMode>
|
||||
struct GemmQuantConfig;
|
||||
|
||||
template <>
|
||||
struct GemmQuantConfig<ck_tile::QuantType::TensorQuant>
|
||||
{
|
||||
template <typename PrecType, bool Persistent>
|
||||
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantConfig<ck_tile::QuantType::RowColQuant>
|
||||
{
|
||||
template <typename PrecType, bool Persistent>
|
||||
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantConfig<ck_tile::QuantType::AQuantGrouped>
|
||||
{
|
||||
template <typename PrecType, bool Persistent>
|
||||
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantConfig<ck_tile::QuantType::BQuantGrouped>
|
||||
{
|
||||
template <typename PrecType, bool Persistent>
|
||||
using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill<PrecType, Persistent>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using GemmPipeline = std::conditional_t<PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<GemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<GemmProblem>>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using BaseGemmPipeline =
|
||||
std::conditional_t<PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>>;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
@@ -148,8 +210,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol")
|
||||
.insert("init", "0", "0. Random, 2. One(s) (Constant)");
|
||||
.insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor, or rowcol")
|
||||
.insert("init", "0", "0. Random, 2. One(s) (Constant)")
|
||||
.insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -57,56 +57,83 @@ float invoke_gemm(int n_warmup,
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
|
||||
// the gemm problems known on the host. Instead, we can just pass the pointer
|
||||
// to the kernel and let the workgroups figure out which tiles to work on.
|
||||
// This is useful when the gemm problems are generated dynamically.
|
||||
// In this example however, we generate the `kargs` using the known gemm_descs,
|
||||
// and copy the gemm descriptions to the device memory.
|
||||
// The contents of the memory pointed to by `kargs_ptr` pointer could be
|
||||
// written by e.g. another kernel from earlier stage.
|
||||
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
assert(args[0].k_batch == 1);
|
||||
for(const auto& arg : args)
|
||||
if constexpr(!GemmConfig::Persistent)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
arg.aq_ptr,
|
||||
arg.bq_ptr,
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.QK_A,
|
||||
arg.QK_B,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
arg.stride_E,
|
||||
arg.stride_AQ,
|
||||
arg.stride_BQ,
|
||||
arg.k_batch});
|
||||
ave_time =
|
||||
grouped_gemm<GemmConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
QuantMode>(args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
|
||||
// the gemm problems known on the host. Instead, we can just pass the pointer
|
||||
// to the kernel and let the workgroups figure out which tiles to work on.
|
||||
// This is useful when the gemm problems are generated dynamically.
|
||||
// In this example however, we generate the `kargs` using the known gemm_descs,
|
||||
// and copy the gemm descriptions to the device memory.
|
||||
// The contents of the memory pointed to by `kargs_ptr` pointer could be
|
||||
// written by e.g. another kernel from earlier stage.
|
||||
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
if(args[0].k_batch != 1)
|
||||
{
|
||||
throw std::runtime_error("Split-K not supported yet for persistent kernel");
|
||||
}
|
||||
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
arg.aq_ptr,
|
||||
arg.bq_ptr,
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.QK_A,
|
||||
arg.QK_B,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
arg.stride_E,
|
||||
arg.stride_AQ,
|
||||
arg.stride_BQ,
|
||||
arg.k_batch});
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time = grouped_gemm_tileloop<GemmConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
QuantMode>(stream, group_count, kargs_ptr);
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time = grouped_gemm_tileloop<GemmConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
QuantMode>(stream, group_count, kargs_ptr);
|
||||
|
||||
std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")";
|
||||
|
||||
@@ -259,13 +286,24 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
|
||||
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize
|
||||
BQK = 0; // No B quantization
|
||||
if(K % QuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be divisible by QuantGroupSize::kK for AQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
AQK = 0; // No A quantization
|
||||
BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize
|
||||
if(K % QuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
|
||||
throw std::runtime_error(
|
||||
"K must be divisible by QuantGroupSize::kK for BQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -284,6 +322,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
stride_AQs[i] =
|
||||
ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout));
|
||||
stride_BQs[i] = 0; // No B quantization
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
stride_AQs[i] = 0; // No A quantization
|
||||
@@ -311,10 +355,17 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout))));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout))));
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(0, 0, stride_BQs[i], is_row_major(bq_layout))));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout))));
|
||||
ck_tile::host_tensor_descriptor(0, 0, stride_AQs[i], is_row_major(aq_layout))));
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout))));
|
||||
}
|
||||
@@ -444,7 +495,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
bq_tensors[i],
|
||||
c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
AQDataType,
|
||||
@@ -452,6 +503,17 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
true>(
|
||||
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
BQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(
|
||||
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
}
|
||||
@@ -477,7 +539,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename PrecType, ck_tile::QuantType QuantMode>
|
||||
template <typename PrecType, ck_tile::QuantType QuantMode, typename GemmConfig>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
@@ -494,6 +556,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
@@ -511,7 +574,24 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
template <typename PrecType, ck_tile::QuantType QuantMode>
|
||||
int run_gemm_example_persistency(
|
||||
std::string a_layout, std::string b_layout, bool persistent, int argc, char* argv[])
|
||||
{
|
||||
if(persistent)
|
||||
{
|
||||
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType, true>;
|
||||
return run_gemm_example_prec_type<PrecType, QuantMode, GemmConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType, false>;
|
||||
return run_gemm_example_prec_type<PrecType, QuantMode, GemmConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
|
||||
int run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -524,29 +604,29 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
std::string quant_mode = arg_parser.get_str("quant_mode");
|
||||
bool persistent = arg_parser.get_bool("persistent");
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -557,24 +637,23 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -163,7 +163,6 @@ struct QuantGroupedGemmKernel
|
||||
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;
|
||||
static_assert(UsePersistentKernel == true, "UsePersistentKernel must be true");
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -262,10 +261,9 @@ struct QuantGroupedGemmKernel
|
||||
auto karg =
|
||||
QuantGroupedGemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
|
||||
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
|
||||
type_convert<CDataType*>(gemm_descs[i].e_ptr),
|
||||
type_convert<const AQDataType*>(gemm_descs[i].aq_ptr),
|
||||
type_convert<const BQDataType*>(gemm_descs[i].bq_ptr),
|
||||
gemm_descs[i].k_batch,
|
||||
type_convert<CDataType*>(gemm_descs[i].e_ptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
@@ -275,7 +273,8 @@ struct QuantGroupedGemmKernel
|
||||
stride_b,
|
||||
stride_e,
|
||||
gemm_descs[i].stride_AQ,
|
||||
gemm_descs[i].stride_BQ};
|
||||
gemm_descs[i].stride_BQ,
|
||||
gemm_descs[i].k_batch};
|
||||
|
||||
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
|
||||
}
|
||||
@@ -342,16 +341,32 @@ struct QuantGroupedGemmKernel
|
||||
else
|
||||
{
|
||||
|
||||
RunGemmWithPipelineSelection(a_ptr,
|
||||
b_ptr,
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
if constexpr(UsePersistentKernel)
|
||||
{
|
||||
RunGemmWithPipelineSelection(a_ptr,
|
||||
b_ptr,
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else // Non-persistent kernel
|
||||
{
|
||||
Base::RunGemm({a_ptr},
|
||||
{b_ptr},
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -451,7 +466,24 @@ struct QuantGroupedGemmKernel
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
aq_block_window,
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
|
||||
// Run GEMM pipeline
|
||||
@@ -496,6 +528,53 @@ struct QuantGroupedGemmKernel
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t FindGroupId(const QuantGemmTransKernelArg* gemm_desc_ptr,
|
||||
index_t block_id,
|
||||
index_t group_count) const
|
||||
{
|
||||
index_t left = 0;
|
||||
index_t right = group_count;
|
||||
index_t group_id = index_t((left + right) >> 1);
|
||||
|
||||
while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
|
||||
block_id < gemm_desc_ptr[group_id].block_end)) &&
|
||||
left <= right)
|
||||
{
|
||||
if(block_id < gemm_desc_ptr[group_id].block_start)
|
||||
{
|
||||
right = group_id;
|
||||
}
|
||||
else
|
||||
{
|
||||
left = group_id;
|
||||
}
|
||||
group_id = index_t((left + right) >> 1);
|
||||
}
|
||||
|
||||
return group_id;
|
||||
}
|
||||
|
||||
// For non-persistent kernels
|
||||
template <bool U = UsePersistentKernel, typename = std::enable_if_t<!U>>
|
||||
CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
index_t group_count) const
|
||||
{
|
||||
const index_t block_id = ck_tile::get_block_1d_id();
|
||||
const auto gemm_desc_ptr = reinterpret_cast<const QuantGemmTransKernelArg*>(
|
||||
cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
|
||||
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
|
||||
const auto& kargs = gemm_desc_ptr[group_id];
|
||||
|
||||
const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
|
||||
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
|
||||
0,
|
||||
kargs.group_karg.M,
|
||||
kargs.group_karg.N,
|
||||
(block_id - kargs.block_start) % grid_size_2d);
|
||||
Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d);
|
||||
}
|
||||
|
||||
// For persistent kernels
|
||||
template <bool U = UsePersistentKernel,
|
||||
typename = std::enable_if_t<U>,
|
||||
|
||||
@@ -319,6 +319,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
constexpr index_t tail_count =
|
||||
((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) ? 1 : 2;
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
@@ -366,7 +368,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 1));
|
||||
} while(i < (num_loop - tail_count));
|
||||
}
|
||||
// tail
|
||||
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
|
||||
@@ -439,6 +441,51 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
/// @brief Runtime pipeline dispatch operator for grouped GEMM kernels.
|
||||
///
|
||||
/// This operator is used by grouped GEMM kernels where pipeline parameters
|
||||
/// (has_hot_loop, num_loop, tail_number) are calculated on the device side
|
||||
/// at runtime, not on the host side during compilation. This is necessary
|
||||
/// because different GEMM problems in the group may have different K dimensions,
|
||||
/// requiring different pipeline configurations that cannot be determined at
|
||||
/// compile time.
|
||||
///
|
||||
/// @param a_dram_block_window_tmp Block window for A tensor in DRAM
|
||||
/// @param b_dram_block_window_tmp Block window for B tensor in DRAM
|
||||
/// @param aq_dram_block_window_tmp Block window for AQ (quantization scale) tensor in DRAM
|
||||
/// @param num_loop Number of main loop iterations (calculated on device)
|
||||
/// @param has_hot_loop Whether the pipeline has a hot loop (calculated on device)
|
||||
/// @param tail_number Type of tail handling required (calculated on device)
|
||||
/// @param p_smem Pointer to shared memory
|
||||
/// @return Accumulated result tile in registers
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AQDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem,
|
||||
index_t m = 0) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) {
|
||||
constexpr bool hot_loop = has_hot_loop_.value;
|
||||
constexpr auto tail_num = tail_number_.value;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
aq_dram_block_window_tmp,
|
||||
m, // dummy value, won't be used
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -14,6 +14,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
@@ -18,32 +18,41 @@ using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
|
||||
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
|
||||
using AQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
|
||||
using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
|
||||
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True>
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>,
|
||||
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>,
|
||||
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>,
|
||||
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>,
|
||||
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
|
||||
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>,
|
||||
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>,
|
||||
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>,
|
||||
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>,
|
||||
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>,
|
||||
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_util_quant.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
using AQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_AQuant = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>,
|
||||
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, True>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_AQuant, KernelTypes_AQuant);
|
||||
|
||||
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_AQuant
|
||||
#include "test_grouped_gemm_quant_ut_cases.inc"
|
||||
#undef TEST_CLASS_NAME
|
||||
@@ -20,9 +20,14 @@ using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQ
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_BQuant = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True>
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>,
|
||||
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, False, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, False, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -20,11 +20,14 @@ using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantTyp
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_RowCol = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
|
||||
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, False, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, False, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -20,11 +20,14 @@ using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantTyp
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_Tensor = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
|
||||
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
#include <sstream>
|
||||
#include <gtest/gtest.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
@@ -32,24 +33,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using AQLayout = Row;
|
||||
using BQLayout = Col;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool PreshuffleB = std::tuple_element_t<10, Tuple>::value;
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
static constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
static constexpr bool Persistent = std::tuple_element_t<11, Tuple>::value;
|
||||
static constexpr bool TransposeC = std::tuple_element_t<12, Tuple>::value;
|
||||
|
||||
struct GroupedGemKernelParam_Mfma
|
||||
{
|
||||
@@ -66,11 +52,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
static const ck_tile::index_t N_Warp = 2;
|
||||
static const ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static const ck_tile::index_t M_Warp_Tile = 32;
|
||||
static const ck_tile::index_t N_Warp_Tile = 32;
|
||||
static const ck_tile::index_t K_Warp_Tile =
|
||||
TestCkTileGroupedGemmQuant::template get_k_from_preshuffled_warp_tile<BDataType,
|
||||
M_Warp_Tile>();
|
||||
static const ck_tile::index_t M_Warp_Tile = 16;
|
||||
static const ck_tile::index_t N_Warp_Tile = 16;
|
||||
static const ck_tile::index_t K_Warp_Tile = 32;
|
||||
};
|
||||
|
||||
struct GroupedGemKernelParam_Wmma : public GroupedGemKernelParam_Mfma
|
||||
@@ -90,16 +74,201 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg);
|
||||
}
|
||||
|
||||
template <typename GroupedGemKernelParam, typename ALayout, typename BLayout, typename CLayout>
|
||||
float invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr bool DoubleSmemBuffer =
|
||||
PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B
|
||||
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantType == ck_tile::QuantType::BQuantGrouped;
|
||||
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
|
||||
GroupedGemKernelParam::N_Tile,
|
||||
GroupedGemKernelParam::K_Tile>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::K_Warp>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
false,
|
||||
PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantType,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
TransposeC,
|
||||
DoubleSmemBuffer,
|
||||
Persistent>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<
|
||||
QuantType == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
|
||||
std::conditional_t<
|
||||
PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GroupedGemKernelParam::K_Tile;
|
||||
const ck_tile::index_t K_split =
|
||||
(gemm_descs[0].K + k_grain - 1) / k_grain * GroupedGemKernelParam::K_Tile;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
|
||||
|
||||
using QuantGemmProblem = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<QuantType == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
TransposeC,
|
||||
BDataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
ADataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
TransposeC,
|
||||
BDataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<
|
||||
QuantType == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
std::conditional_t<PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName()
|
||||
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
|
||||
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
|
||||
<< blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GroupedGemKernelParam::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <typename GroupedGemKernelParam, typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr bool TransposeC = false;
|
||||
constexpr bool DoubleSmemBuffer =
|
||||
PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
@@ -131,40 +300,53 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
BQLayout,
|
||||
TransposeC,
|
||||
DoubleSmemBuffer,
|
||||
true>;
|
||||
Persistent>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using QuantGemmProblem = typename std::conditional<
|
||||
QuantType == ck_tile::QuantType::BQuantGrouped,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize>,
|
||||
|
||||
constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantType == ck_tile::QuantType::BQuantGrouped;
|
||||
using QuantGemmProblem = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<QuantType == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
TransposeC>,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize>>,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
transpose_c,
|
||||
TransposeC,
|
||||
BDataType,
|
||||
scheduler>>::type;
|
||||
scheduler>>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantType == ck_tile::QuantType::RowColQuant ||
|
||||
QuantType == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
std::conditional_t<PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>;
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<
|
||||
QuantType == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
std::conditional_t<PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -199,7 +381,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<kBlockPerCu>(
|
||||
ck_tile::make_kernel<GroupedGemKernelParam::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
@@ -292,13 +474,24 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
|
||||
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
|
||||
}
|
||||
else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize
|
||||
BQK = 0; // No B quantization
|
||||
if(K % QuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be divisible by QuantGroupSize::kK for AQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
AQK = 0; // No A quantization
|
||||
BQK = K / 128; // Group quantization: BQK = K / GroupSize
|
||||
if(K % 128 != 0)
|
||||
AQK = 0; // No A quantization
|
||||
BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize
|
||||
if(K % QuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
|
||||
throw std::runtime_error(
|
||||
"K must be divisible by QuantGroupSize::kK for BQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -317,6 +510,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
}
|
||||
else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
stride_AQs[i] =
|
||||
ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(AQLayout()));
|
||||
stride_BQs[i] = 0; // No B quantization
|
||||
}
|
||||
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
stride_AQs[i] = 0; // No A quantization
|
||||
@@ -348,11 +547,20 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
|
||||
1, 1, stride_BQs[i], is_row_major(BQLayout()))));
|
||||
}
|
||||
else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
aq_tensors.push_back(
|
||||
ck_tile::HostTensor<AQDataType>(ck_tile::host_tensor_descriptor(
|
||||
M, AQK, stride_AQs[i], is_row_major(AQLayout{}))));
|
||||
bq_tensors.push_back(
|
||||
ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
|
||||
0, 0, stride_BQs[i], is_row_major(BQLayout()))));
|
||||
}
|
||||
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
aq_tensors.push_back(
|
||||
ck_tile::HostTensor<AQDataType>(ck_tile::host_tensor_descriptor(
|
||||
0, AQK, stride_AQs[i], is_row_major(AQLayout{}))));
|
||||
0, 0, stride_AQs[i], is_row_major(AQLayout{}))));
|
||||
bq_tensors.push_back(
|
||||
ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
|
||||
BQK, N, stride_BQs[i], is_row_major(BQLayout()))));
|
||||
@@ -429,11 +637,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
|
||||
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
// Generate kernel arguments
|
||||
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
assert(gemm_descs[0].k_batch == 1);
|
||||
for(const auto& arg : gemm_descs)
|
||||
{
|
||||
@@ -471,7 +680,14 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
}
|
||||
else
|
||||
{
|
||||
GTEST_FAIL() << "Non-persistent kernel not implemented yet";
|
||||
const auto stream = ck_tile::stream_config{nullptr, false, 1};
|
||||
#if CK_TILE_USE_WMMA
|
||||
invoke_grouped_gemm<GroupedGemKernelParam_Wmma, ALayout, BLayout, CLayout>(
|
||||
gemm_descs, stream, kargs_ptr);
|
||||
#else
|
||||
invoke_grouped_gemm<GroupedGemKernelParam_Mfma, ALayout, BLayout, CLayout>(
|
||||
gemm_descs, stream, kargs_ptr);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Copy results back to host for validation
|
||||
@@ -512,7 +728,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
bq_tensors[i],
|
||||
c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
|
||||
else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
AQDataType,
|
||||
@@ -520,6 +736,17 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
true>(
|
||||
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
BQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(
|
||||
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
}
|
||||
@@ -550,5 +777,8 @@ using TestCkTileGroupedGemmQuant_RowCol = TestCkTileGroupedGemmQuant<Tuple>;
|
||||
template <typename Tuple>
|
||||
using TestCkTileGroupedGemmQuant_Tensor = TestCkTileGroupedGemmQuant<Tuple>;
|
||||
|
||||
template <typename Tuple>
|
||||
using TestCkTileGroupedGemmQuant_AQuant = TestCkTileGroupedGemmQuant<Tuple>;
|
||||
|
||||
template <typename Tuple>
|
||||
using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant<Tuple>;
|
||||
|
||||
Reference in New Issue
Block a user