mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
update kernel
This commit is contained in:
@@ -19,6 +19,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "abquant_grouped_gemm.hpp"
|
||||
|
||||
// Non-persistent grouped gemm for ABQuant
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
@@ -33,7 +34,158 @@ template <typename GemmConfig,
|
||||
typename CDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped>
|
||||
ck_tile::QuantType QuantMode>
|
||||
float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
GemmConfig::Persistent>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
|
||||
|
||||
using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
// Persistent grouped gemm tileloop for ABQuant
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr)
|
||||
|
||||
@@ -143,6 +143,27 @@ inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gem
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg);
|
||||
}
|
||||
|
||||
// Forward declaration of the non-persistent version
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped>
|
||||
float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr);
|
||||
|
||||
// Forward declaration of the tileloop version for persistent kernels
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
@@ -157,8 +178,9 @@ template <typename GemmConfig,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize>
|
||||
float grouped_gemm_abquant_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr);
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr);
|
||||
|
||||
|
||||
@@ -74,54 +74,85 @@ float invoke_abquant_gemm(int n_warmup,
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
// Persistent TileLoop kernel only
|
||||
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
if(args[0].k_batch != 1)
|
||||
if constexpr(!GemmConfig::Persistent)
|
||||
{
|
||||
throw std::runtime_error("Split-K not supported yet for persistent kernel");
|
||||
ave_time =
|
||||
grouped_gemm_abquant<GemmConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
|
||||
// the gemm problems known on the host. Instead, we can just pass the pointer
|
||||
// to the kernel and let the workgroups figure out which tiles to work on.
|
||||
// This is useful when the gemm problems are generated dynamically.
|
||||
// In this example however, we generate the `kargs` using the known gemm_descs,
|
||||
// and copy the gemm descriptions to the device memory.
|
||||
// The contents of the memory pointed to by `kargs_ptr` pointer could be
|
||||
// written by e.g. another kernel from earlier stage.
|
||||
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
if(args[0].k_batch != 1)
|
||||
{
|
||||
throw std::runtime_error("Split-K not supported yet for persistent kernel");
|
||||
}
|
||||
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
arg.aq_ptr,
|
||||
arg.bq_ptr,
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.QK_A,
|
||||
arg.QK_B,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
arg.stride_E,
|
||||
arg.stride_AQ,
|
||||
arg.stride_BQ,
|
||||
arg.k_batch});
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
arg.aq_ptr,
|
||||
arg.bq_ptr,
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.QK_A,
|
||||
arg.QK_B,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
arg.stride_E,
|
||||
arg.stride_AQ,
|
||||
arg.stride_BQ,
|
||||
arg.k_batch});
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time = grouped_gemm_tileloop<GemmConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(stream, group_count, kargs_ptr);
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time = grouped_gemm_tileloop<GemmConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(stream, group_count, kargs_ptr);
|
||||
|
||||
std::string op_name = "ABQuant Grouped Gemm";
|
||||
|
||||
@@ -426,11 +457,10 @@ int run_abquant_grouped_gemm_example_with_layouts(int argc,
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename PrecType>
|
||||
template <typename PrecType, typename GemmConfig>
|
||||
int run_abquant_grouped_gemm_example_prec_type(std::string a_layout,
|
||||
std::string b_layout,
|
||||
std::string c_layout,
|
||||
[[maybe_unused]] bool persistent,
|
||||
int argc,
|
||||
char* argv[])
|
||||
{
|
||||
@@ -447,9 +477,6 @@ int run_abquant_grouped_gemm_example_prec_type(std::string a_layout,
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
using GemmConfig = typename GemmQuantConfig<ck_tile::QuantType::ABQuantGrouped>::
|
||||
template GemmConfig<PrecType, true>;
|
||||
|
||||
// Support RCR, RRR, CRR layouts
|
||||
if(a_layout == "R" && b_layout == "C" && c_layout == "R")
|
||||
{
|
||||
@@ -496,6 +523,30 @@ int run_abquant_grouped_gemm_example_prec_type(std::string a_layout,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PrecType>
|
||||
int run_abquant_gemm_example_persistency(std::string a_layout,
|
||||
std::string b_layout,
|
||||
std::string c_layout,
|
||||
bool persistent,
|
||||
int argc,
|
||||
char* argv[])
|
||||
{
|
||||
if(persistent)
|
||||
{
|
||||
using GemmConfig = typename GemmQuantConfig<ck_tile::QuantType::ABQuantGrouped>::
|
||||
template GemmConfig<PrecType, true>;
|
||||
return run_abquant_grouped_gemm_example_prec_type<PrecType, GemmConfig>(
|
||||
a_layout, b_layout, c_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
using GemmConfig = typename GemmQuantConfig<ck_tile::QuantType::ABQuantGrouped>::
|
||||
template GemmConfig<PrecType, false>;
|
||||
return run_abquant_grouped_gemm_example_prec_type<PrecType, GemmConfig>(
|
||||
a_layout, b_layout, c_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
|
||||
int run_abquant_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -508,7 +559,7 @@ int run_abquant_grouped_gemm_example(int argc, char* argv[])
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string c_layout = arg_parser.get_str("c_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
const bool persistent = arg_parser.get_bool("persistent");
|
||||
bool persistent = arg_parser.get_bool("persistent");
|
||||
|
||||
// Validate layout combinations
|
||||
if(!((a_layout == "R" && b_layout == "C" && c_layout == "R") ||
|
||||
@@ -522,12 +573,12 @@ int run_abquant_grouped_gemm_example(int argc, char* argv[])
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
return run_abquant_grouped_gemm_example_prec_type<ck_tile::fp8_t>(
|
||||
return run_abquant_gemm_example_persistency<ck_tile::fp8_t>(
|
||||
a_layout, b_layout, c_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_abquant_grouped_gemm_example_prec_type<ck_tile::bf8_t>(
|
||||
return run_abquant_gemm_example_persistency<ck_tile::bf8_t>(
|
||||
a_layout, b_layout, c_layout, persistent, argc, argv);
|
||||
}
|
||||
else
|
||||
|
||||
@@ -17,7 +17,10 @@ endif()
|
||||
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp)
|
||||
# target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
|
||||
# target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
# endif()
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_abquant test_grouped_gemm_quant_abquant.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_abquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_util_quant.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
using ABQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_ABQuant = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, True, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, True, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, True, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, ABQuant, False, True, False>,
|
||||
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, False, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, ABQuant, False, False, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_ABQuant, KernelTypes_ABQuant);
|
||||
|
||||
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_ABQuant
|
||||
#include "test_grouped_gemm_quant_ut_cases.inc"
|
||||
#undef TEST_CLASS_NAME
|
||||
|
||||
@@ -85,7 +85,8 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantType == ck_tile::QuantType::BQuantGrouped;
|
||||
QuantType == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantType == ck_tile::QuantType::ABQuantGrouped;
|
||||
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
@@ -168,17 +169,32 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
ADataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>,
|
||||
std::conditional_t<QuantType == ck_tile::QuantType::BQuantGrouped,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
ADataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
ck_tile::GemmABQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
QuantGroupSize,
|
||||
TransposeC,
|
||||
BDataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>>,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
@@ -196,9 +212,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
std::conditional_t<
|
||||
QuantType == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
std::conditional_t<PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
|
||||
std::conditional_t<
|
||||
QuantType == ck_tile::QuantType::BQuantGrouped,
|
||||
std::conditional_t<PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
@@ -309,7 +328,8 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
|
||||
constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantType == ck_tile::QuantType::BQuantGrouped;
|
||||
QuantType == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantType == ck_tile::QuantType::ABQuantGrouped;
|
||||
using QuantGemmProblem = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<QuantType == ck_tile::QuantType::AQuantGrouped,
|
||||
@@ -321,13 +341,24 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
TransposeC>,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize>>,
|
||||
std::conditional_t<QuantType == ck_tile::QuantType::BQuantGrouped,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize>,
|
||||
ck_tile::GemmABQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
QuantGroupSize,
|
||||
TransposeC>>>,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
@@ -343,9 +374,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
std::conditional_t<
|
||||
QuantType == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
std::conditional_t<PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
|
||||
std::conditional_t<
|
||||
QuantType == ck_tile::QuantType::BQuantGrouped,
|
||||
std::conditional_t<PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
@@ -494,6 +528,16 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
"K must be divisible by QuantGroupSize::kK for BQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
AQK = K / QuantGroupSize::kK; // Group quantization for A: AQK = K / GroupSize
|
||||
BQK = K / QuantGroupSize::kK; // Group quantization for B: BQK = K / GroupSize
|
||||
if(K % QuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be divisible by QuantGroupSize::kK for ABQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
|
||||
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(ALayout{}));
|
||||
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(BLayout{}));
|
||||
@@ -522,6 +566,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
stride_BQs[i] =
|
||||
ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(BQLayout()));
|
||||
}
|
||||
else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
stride_AQs[i] =
|
||||
ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(AQLayout()));
|
||||
stride_BQs[i] =
|
||||
ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(BQLayout()));
|
||||
}
|
||||
|
||||
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(ALayout{}))));
|
||||
@@ -565,6 +616,15 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
|
||||
BQK, N, stride_BQs[i], is_row_major(BQLayout()))));
|
||||
}
|
||||
else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
aq_tensors.push_back(
|
||||
ck_tile::HostTensor<AQDataType>(ck_tile::host_tensor_descriptor(
|
||||
M, AQK, stride_AQs[i], is_row_major(AQLayout{}))));
|
||||
bq_tensors.push_back(
|
||||
ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
|
||||
BQK, N, stride_BQs[i], is_row_major(BQLayout()))));
|
||||
}
|
||||
|
||||
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc
|
||||
@@ -750,6 +810,18 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
false>(
|
||||
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantType == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
ck_tile::reference_gemm_abquant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
QuantGroupSize>(
|
||||
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref);
|
||||
}
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
@@ -782,3 +854,6 @@ using TestCkTileGroupedGemmQuant_AQuant = TestCkTileGroupedGemmQuant<Tuple>;
|
||||
|
||||
template <typename Tuple>
|
||||
using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant<Tuple>;
|
||||
|
||||
template <typename Tuple>
|
||||
using TestCkTileGroupedGemmQuant_ABQuant = TestCkTileGroupedGemmQuant<Tuple>;
|
||||
|
||||
Reference in New Issue
Block a user