update kernel

This commit is contained in:
kyle-256
2025-12-16 07:20:58 +00:00
parent 2198f8b583
commit 8fafd6db2f
6 changed files with 428 additions and 86 deletions

View File

@@ -19,6 +19,7 @@
#include "ck_tile/host.hpp"
#include "abquant_grouped_gemm.hpp"
// Non-persistent grouped gemm for ABQuant
template <typename GemmConfig,
typename ALayout,
typename AQLayout,
@@ -33,7 +34,158 @@ template <typename GemmConfig,
typename CDataType,
typename AQuantGroupSize,
typename BQuantGroupSize,
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped>
ck_tile::QuantType QuantMode>
float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
{
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
ALayout,
BLayout,
CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false, // PreshuffleQuant
GemmConfig::PreshuffleB,
ALayout,
BLayout,
CLayout,
QuantMode,
AQLayout,
BQLayout,
GemmConfig::TransposeC,
GemmConfig::DoubleSmemBuffer,
GemmConfig::Persistent>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline =
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
GemmConfig::PreshuffleB>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
AQuantGroupSize,
BQuantGroupSize,
GemmConfig::TransposeC,
BDataType,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline =
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
QuantGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
GemmUniversalTraits::kQuantType>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
};
return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
// Persistent grouped gemm tileloop for ABQuant
template <typename GemmConfig,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout,
typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename AQuantGroupSize,
typename BQuantGroupSize,
ck_tile::QuantType QuantMode>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr)

View File

@@ -143,6 +143,27 @@ inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gem
{
return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg);
}
// Forward declaration of the non-persistent version
template <typename GemmConfig,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout,
typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename AQuantGroupSize,
typename BQuantGroupSize,
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped>
float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr);
// Forward declaration of the tileloop version for persistent kernels
template <typename GemmConfig,
typename ALayout,
@@ -157,8 +178,9 @@ template <typename GemmConfig,
typename AccDataType,
typename CDataType,
typename AQuantGroupSize,
typename BQuantGroupSize>
float grouped_gemm_abquant_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr);
typename BQuantGroupSize,
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr);

View File

@@ -74,54 +74,85 @@ float invoke_abquant_gemm(int n_warmup,
float ave_time = 0;
// Persistent TileLoop kernel only
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
if(args[0].k_batch != 1)
if constexpr(!GemmConfig::Persistent)
{
throw std::runtime_error("Split-K not supported yet for persistent kernel");
ave_time =
grouped_gemm_abquant<GemmConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
gemm_workspace.GetDeviceBuffer());
}
else
{
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
// the gemm problems known on the host. Instead, we can just pass the pointer
// to the kernel and let the workgroups figure out which tiles to work on.
// This is useful when the gemm problems are generated dynamically.
// In this example however, we generate the `kargs` using the known gemm_descs,
// and copy the gemm descriptions to the device memory.
// The contents of the memory pointed to by `kargs_ptr` pointer could be
// written by e.g. another kernel from earlier stage.
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
if(args[0].k_batch != 1)
{
throw std::runtime_error("Split-K not supported yet for persistent kernel");
}
for(const auto& arg : args)
{
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
arg.b_ptr,
arg.aq_ptr,
arg.bq_ptr,
arg.e_ptr,
arg.M,
arg.N,
arg.K,
arg.QK_A,
arg.QK_B,
arg.stride_A,
arg.stride_B,
arg.stride_E,
arg.stride_AQ,
arg.stride_BQ,
arg.k_batch});
for(const auto& arg : args)
{
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
arg.b_ptr,
arg.aq_ptr,
arg.bq_ptr,
arg.e_ptr,
arg.M,
arg.N,
arg.K,
arg.QK_A,
arg.QK_B,
arg.stride_A,
arg.stride_B,
arg.stride_E,
arg.stride_AQ,
arg.stride_BQ,
arg.k_batch});
}
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time = grouped_gemm_tileloop<GemmConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(stream, group_count, kargs_ptr);
}
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time = grouped_gemm_tileloop<GemmConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(stream, group_count, kargs_ptr);
std::string op_name = "ABQuant Grouped Gemm";
@@ -426,11 +457,10 @@ int run_abquant_grouped_gemm_example_with_layouts(int argc,
return pass;
}
template <typename PrecType>
template <typename PrecType, typename GemmConfig>
int run_abquant_grouped_gemm_example_prec_type(std::string a_layout,
std::string b_layout,
std::string c_layout,
[[maybe_unused]] bool persistent,
int argc,
char* argv[])
{
@@ -447,9 +477,6 @@ int run_abquant_grouped_gemm_example_prec_type(std::string a_layout,
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using GemmConfig = typename GemmQuantConfig<ck_tile::QuantType::ABQuantGrouped>::
template GemmConfig<PrecType, true>;
// Support RCR, RRR, CRR layouts
if(a_layout == "R" && b_layout == "C" && c_layout == "R")
{
@@ -496,6 +523,30 @@ int run_abquant_grouped_gemm_example_prec_type(std::string a_layout,
}
}
template <typename PrecType>
int run_abquant_gemm_example_persistency(std::string a_layout,
std::string b_layout,
std::string c_layout,
bool persistent,
int argc,
char* argv[])
{
if(persistent)
{
using GemmConfig = typename GemmQuantConfig<ck_tile::QuantType::ABQuantGrouped>::
template GemmConfig<PrecType, true>;
return run_abquant_grouped_gemm_example_prec_type<PrecType, GemmConfig>(
a_layout, b_layout, c_layout, argc, argv);
}
else
{
using GemmConfig = typename GemmQuantConfig<ck_tile::QuantType::ABQuantGrouped>::
template GemmConfig<PrecType, false>;
return run_abquant_grouped_gemm_example_prec_type<PrecType, GemmConfig>(
a_layout, b_layout, c_layout, argc, argv);
}
}
int run_abquant_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -508,7 +559,7 @@ int run_abquant_grouped_gemm_example(int argc, char* argv[])
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string c_layout = arg_parser.get_str("c_layout");
const std::string data_type = arg_parser.get_str("prec");
const bool persistent = arg_parser.get_bool("persistent");
bool persistent = arg_parser.get_bool("persistent");
// Validate layout combinations
if(!((a_layout == "R" && b_layout == "C" && c_layout == "R") ||
@@ -522,12 +573,12 @@ int run_abquant_grouped_gemm_example(int argc, char* argv[])
if(data_type == "fp8")
{
return run_abquant_grouped_gemm_example_prec_type<ck_tile::fp8_t>(
return run_abquant_gemm_example_persistency<ck_tile::fp8_t>(
a_layout, b_layout, c_layout, persistent, argc, argv);
}
else if(data_type == "bf8")
{
return run_abquant_grouped_gemm_example_prec_type<ck_tile::bf8_t>(
return run_abquant_gemm_example_persistency<ck_tile::bf8_t>(
a_layout, b_layout, c_layout, persistent, argc, argv);
}
else

View File

@@ -17,7 +17,10 @@ endif()
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp)
# target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
# target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# endif()
add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_grouped_gemm_quant_abquant test_grouped_gemm_quant_abquant.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_abquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()

View File

@@ -0,0 +1,39 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_util_quant.hpp"
using F16 = ck_tile::half_t;
using F32 = float;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using True = ck_tile::bool_constant<true>;
using False = ck_tile::bool_constant<false>;
using ABQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
// clang-format off
using KernelTypes_ABQuant = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, True, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, True, False>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, True, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, ABQuant, False, True, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, False, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, False, False>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_ABQuant, KernelTypes_ABQuant);
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_ABQuant
#include "test_grouped_gemm_quant_ut_cases.inc"
#undef TEST_CLASS_NAME

View File

@@ -85,7 +85,8 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped ||
QuantType == ck_tile::QuantType::BQuantGrouped;
QuantType == ck_tile::QuantType::BQuantGrouped ||
QuantType == ck_tile::QuantType::ABQuantGrouped;
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
@@ -168,17 +169,32 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
scheduler,
has_hot_loop_v,
tail_number_v>,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
ADataType,
scheduler,
has_hot_loop_v,
tail_number_v>>,
std::conditional_t<QuantType == ck_tile::QuantType::BQuantGrouped,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
ADataType,
scheduler,
has_hot_loop_v,
tail_number_v>,
ck_tile::GemmABQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
QuantGroupSize,
TransposeC,
BDataType,
scheduler,
has_hot_loop_v,
tail_number_v>>>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
@@ -196,9 +212,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
std::conditional_t<
QuantType == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
std::conditional_t<
QuantType == ck_tile::QuantType::BQuantGrouped,
std::conditional_t<PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>,
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
@@ -309,7 +328,8 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
// These are automatically run inside the kernel based on the given input data.
constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped ||
QuantType == ck_tile::QuantType::BQuantGrouped;
QuantType == ck_tile::QuantType::BQuantGrouped ||
QuantType == ck_tile::QuantType::ABQuantGrouped;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
std::conditional_t<QuantType == ck_tile::QuantType::AQuantGrouped,
@@ -321,13 +341,24 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
GemmUniversalTraits,
QuantGroupSize,
TransposeC>,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize>>,
std::conditional_t<QuantType == ck_tile::QuantType::BQuantGrouped,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize>,
ck_tile::GemmABQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
QuantGroupSize,
TransposeC>>>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
@@ -343,9 +374,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
std::conditional_t<
QuantType == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
std::conditional_t<
QuantType == ck_tile::QuantType::BQuantGrouped,
std::conditional_t<PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>,
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
@@ -494,6 +528,16 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
"K must be divisible by QuantGroupSize::kK for BQuantGrouped mode");
}
}
else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped)
{
AQK = K / QuantGroupSize::kK; // Group quantization for A: AQK = K / GroupSize
BQK = K / QuantGroupSize::kK; // Group quantization for B: BQK = K / GroupSize
if(K % QuantGroupSize::kK != 0)
{
throw std::runtime_error(
"K must be divisible by QuantGroupSize::kK for ABQuantGrouped mode");
}
}
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(ALayout{}));
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(BLayout{}));
@@ -522,6 +566,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
stride_BQs[i] =
ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(BQLayout()));
}
else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped)
{
stride_AQs[i] =
ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(AQLayout()));
stride_BQs[i] =
ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(BQLayout()));
}
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(ALayout{}))));
@@ -565,6 +616,15 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
BQK, N, stride_BQs[i], is_row_major(BQLayout()))));
}
else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped)
{
aq_tensors.push_back(
ck_tile::HostTensor<AQDataType>(ck_tile::host_tensor_descriptor(
M, AQK, stride_AQs[i], is_row_major(AQLayout{}))));
bq_tensors.push_back(
ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
BQK, N, stride_BQs[i], is_row_major(BQLayout()))));
}
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
<< " b_k_n: " << b_k_n_tensors[i].mDesc
@@ -750,6 +810,18 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
false>(
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
}
else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped)
{
ck_tile::reference_gemm_abquant<ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
QuantGroupSize,
QuantGroupSize>(
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref);
}
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
@@ -782,3 +854,6 @@ using TestCkTileGroupedGemmQuant_AQuant = TestCkTileGroupedGemmQuant<Tuple>;
template <typename Tuple>
using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant<Tuple>;
template <typename Tuple>
using TestCkTileGroupedGemmQuant_ABQuant = TestCkTileGroupedGemmQuant<Tuple>;