mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
[CK_TILE] Tensor-wise scaled quant gemm kernel (#2846)
* rename gemm_group_quant to gemm_quant
* Add TensorWise quant mode
* Cshuffle epilogue tests with tensor scaling
* Add tensor quant to example
* Don't use readfirstlane for reading scales - doesn't work for some reason
* Add to changelog
* revert include - from a merge problem?
* revert common.hpp include
* revert host.hpp include
* remove unused utility function
* rename quant pipeline problem
* refactor quant tests
* remove aquant utils
* use TEST_F
* fix all tests by changing gemm config
* Use typed tests
* fix copyright
[ROCm/composable_kernel commit: 4363a82bd6]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -158,4 +158,7 @@ CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<Windo
|
||||
{
|
||||
}
|
||||
|
||||
// Concept that holds when the expression `load_tile(std::declval<Tile>())` is
// well-formed, i.e. when some visible `load_tile` overload can be called with a
// value of type `Tile`. Only the validity of the call is checked; nothing is
// required of the call's return type.
template <typename Tile>
|
||||
concept IsLoadableTile = requires { load_tile(std::declval<Tile>()); };
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -180,10 +180,6 @@ CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
@@ -198,7 +194,57 @@ CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename ACCElementOp = ck_tile::identity>
|
||||
CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<AQDataType>& aq_1_1,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
const HostTensor<BQDataType>& bq_1_1,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
|
||||
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
|
||||
static_assert(std::is_same_v<AccDataType, float>);
|
||||
static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
|
||||
static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto f_mn = [&](auto m, auto n) {
|
||||
// Init accumulator
|
||||
AccDataType v_acc = 0;
|
||||
// Get scale for A and scale for B
|
||||
const AccDataType a_scale = ck_tile::type_convert<AccDataType>(aq_1_1(0, 0));
|
||||
const AccDataType b_scale = ck_tile::type_convert<AccDataType>(bq_1_1(0, 0));
|
||||
|
||||
// Compute the dot product
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
AccDataType v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
AccDataType v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
v_acc = v_acc * a_scale * b_scale;
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
|
||||
@@ -304,22 +304,41 @@ struct CShuffleEpilogue
|
||||
CK_TILE_DEVICE void
|
||||
scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
|
||||
{
|
||||
// Load tiles
|
||||
const auto scale_m_tile = load_tile(scale_m_window);
|
||||
const auto scale_n_tile = load_tile(scale_n_window);
|
||||
|
||||
// Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
|
||||
tile_elementwise_inout(
|
||||
element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
|
||||
|
||||
// Move scale windows
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
// Check if scales are EmptyScale first (no scaling needed)
|
||||
if constexpr(std::is_same_v<ScaleM, EmptyScale> && std::is_same_v<ScaleN, EmptyScale>)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
// No scaling needed - this is a no-op
|
||||
}
|
||||
// Check if scales are scalar AccDataType
|
||||
else if constexpr(std::is_same_v<ScaleM, AccDataType> &&
|
||||
std::is_same_v<ScaleN, AccDataType>)
|
||||
{
|
||||
// Handle scalar scales
|
||||
const AccDataType scale_m = scale_m_window;
|
||||
const AccDataType scale_n = scale_n_window;
|
||||
tile_elementwise_inout([&](auto& element) { element = element * scale_m * scale_n; },
|
||||
lds_tile);
|
||||
}
|
||||
// Otherwise, assume they are tile windows that can be loaded
|
||||
else
|
||||
{
|
||||
// Load tiles
|
||||
const auto scale_m_tile = load_tile(scale_m_window);
|
||||
const auto scale_n_tile = load_tile(scale_n_window);
|
||||
|
||||
move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
// Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
|
||||
tile_elementwise_inout(
|
||||
element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
|
||||
|
||||
// Move scale windows
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -452,6 +471,8 @@ struct CShuffleEpilogue
|
||||
// Optional scales (must share the same distribution to match per-thread indexing)
|
||||
constexpr bool has_scales =
|
||||
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
|
||||
constexpr bool has_scalar_scales =
|
||||
std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
|
||||
|
||||
// Tiles to hold row/col scales when present
|
||||
using SMType = typename GetDataType<remove_cvref_t<ScaleM>>::type;
|
||||
@@ -462,8 +483,11 @@ struct CShuffleEpilogue
|
||||
|
||||
// Build windows only if scales are provided
|
||||
auto scale_m_window = [&]() {
|
||||
if constexpr(has_scales)
|
||||
if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
static_assert(
|
||||
IsLoadableTile<decltype(make_tile_window(scale_m, dram_tile_distribution))>,
|
||||
"ScaleM must be a loadable tile");
|
||||
return make_tile_window(scale_m, dram_tile_distribution);
|
||||
}
|
||||
else
|
||||
@@ -472,8 +496,11 @@ struct CShuffleEpilogue
|
||||
}
|
||||
}();
|
||||
auto scale_n_window = [&]() {
|
||||
if constexpr(has_scales)
|
||||
if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
static_assert(
|
||||
IsLoadableTile<decltype(make_tile_window(scale_n, dram_tile_distribution))>,
|
||||
"ScaleN must be a loadable tile");
|
||||
return make_tile_window(scale_n, dram_tile_distribution);
|
||||
}
|
||||
else
|
||||
@@ -489,7 +516,7 @@ struct CShuffleEpilogue
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
// If scales provided, load them with identical distribution
|
||||
if constexpr(has_scales)
|
||||
if constexpr(has_scales && IsLoadableTile<ScaleM> && IsLoadableTile<ScaleN>)
|
||||
{
|
||||
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
|
||||
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
|
||||
@@ -504,7 +531,11 @@ struct CShuffleEpilogue
|
||||
auto emit = [&](index_t out_idx, index_t src_row) {
|
||||
AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row];
|
||||
|
||||
if constexpr(has_scales)
|
||||
if constexpr(has_scalar_scales)
|
||||
{
|
||||
v = static_cast<AccDataType>(v * scale_m * scale_n);
|
||||
}
|
||||
else if constexpr(has_scales)
|
||||
{
|
||||
// same linear index mapping on the permuted distribution
|
||||
const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
|
||||
@@ -595,10 +626,19 @@ struct CShuffleEpilogue
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr bool has_scales =
|
||||
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
|
||||
!std::is_same_v<ScaleM, EmptyScale> && !std::is_same_v<ScaleN, EmptyScale>;
|
||||
constexpr bool has_scalar_scales =
|
||||
std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
|
||||
auto scale_m_window = [&]() {
|
||||
if constexpr(has_scales)
|
||||
if constexpr(has_scalar_scales)
|
||||
{
|
||||
return scale_m;
|
||||
}
|
||||
else if constexpr(has_scales)
|
||||
{
|
||||
static_assert(
|
||||
IsLoadableTile<decltype(make_tile_window(scale_m, dram_tile_distribution))>,
|
||||
"ScaleM must be a loadable tile");
|
||||
return make_tile_window(scale_m, lds_tile.get_tile_distribution());
|
||||
}
|
||||
else
|
||||
@@ -607,8 +647,15 @@ struct CShuffleEpilogue
|
||||
}
|
||||
}();
|
||||
auto scale_n_window = [&]() {
|
||||
if constexpr(has_scales)
|
||||
if constexpr(has_scalar_scales)
|
||||
{
|
||||
return scale_n;
|
||||
}
|
||||
else if constexpr(has_scales)
|
||||
{
|
||||
static_assert(
|
||||
IsLoadableTile<decltype(make_tile_window(scale_n, dram_tile_distribution))>,
|
||||
"ScaleN must be a loadable tile");
|
||||
return make_tile_window(scale_n, lds_tile.get_tile_distribution());
|
||||
}
|
||||
else
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/kernel/grouped_gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
21
include/ck_tile/ops/gemm_quant.hpp
Normal file
21
include/ck_tile/ops/gemm_quant.hpp
Normal file
@@ -0,0 +1,21 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -330,7 +330,6 @@ struct QuantGemmKernel
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: no kernel currently uses BQuant like this:
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
@@ -890,6 +889,7 @@ struct QuantGemmKernel
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param aq_ptr input AQ pointer
|
||||
* @param bq_ptr input BQ pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
@@ -938,7 +938,8 @@ struct QuantGemmKernel
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant ||
|
||||
kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
@@ -964,6 +965,18 @@ struct QuantGemmKernel
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
// TODO: why doesn't readfirstlane work here?
|
||||
// const AccDataType aq_scale =
|
||||
// __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*aq_ptr));
|
||||
// const AccDataType bq_scale =
|
||||
// __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*bq_ptr));
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -168,17 +168,18 @@ template <typename ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full>
|
||||
using GemmRowColQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
|
||||
AccDataType_,
|
||||
BDataType_,
|
||||
AccDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
1, // no group size applicable
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
using GemmRowColTensorQuantPipelineProblem =
|
||||
GemmQuantPipelineProblemBase<ADataType_,
|
||||
AccDataType_,
|
||||
BDataType_,
|
||||
AccDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
1, // no group size applicable
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
} // namespace ck_tile
|
||||
@@ -12,9 +12,22 @@ enum struct QuantType : std::uint16_t
|
||||
{
|
||||
AQuantGrouped = 0,
|
||||
BQuantGrouped = 1,
|
||||
RowColQuant = 2
|
||||
RowColQuant = 2,
|
||||
TensorQuant = 3
|
||||
};
|
||||
|
||||
std::string quant_type_to_string(QuantType quant_type)
|
||||
{
|
||||
switch(quant_type)
|
||||
{
|
||||
case QuantType::AQuantGrouped: return "AQuantGrouped";
|
||||
case QuantType::BQuantGrouped: return "BQuantGrouped";
|
||||
case QuantType::RowColQuant: return "RowColQuant";
|
||||
case QuantType::TensorQuant: return "TensorQuant";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
Reference in New Issue
Block a user