[CK_TILE] Tensor-wise scaled quant gemm kernel (#2846)

* rename gemm_group_quant to gemm_quant

* Add TensorWise quant mode

* Cshuffle epilogue tests with tensor scaling

* Add tensor quant to example

* Don't use readfirstlane for reading scales - doesn't work for some reason

* Add to changelog

* revert include - from a merge problem?

* revert common.hpp include

* revert host.hpp include

* remove unused utility function

* rename quant pipeline problem

* refactor quant tests

* remove aquant utils

* use TEST_F

* fix all tests by changing gemm config

* Use typed tests

* fix copyright

[ROCm/composable_kernel commit: 4363a82bd6]
This commit is contained in:
Sami Remes
2025-09-20 02:52:35 +03:00
committed by GitHub
parent ee43f0f0be
commit 8d2a444c55
39 changed files with 1555 additions and 1056 deletions

View File

@@ -13,7 +13,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_group_quant.hpp"
#include "ck_tile/ops/gemm_quant.hpp"
#include "ck_tile/host.hpp"
#include "quant_grouped_gemm.hpp"
@@ -65,15 +65,15 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
constexpr auto memory_operation = memory_operation_.value;
constexpr bool transpose_c = false;
using QuantGemmProblem = ck_tile::GemmRowColQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
transpose_c,
BDataType,
scheduler>;
using QuantGemmProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
transpose_c,
BDataType,
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<QuantGemmProblem>;

View File

@@ -5,6 +5,7 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming
- AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline
- BQuant kernel with blocks of B matrix sharing scales: custom GEMM pipeline
- Row and Column-wise scaled: scaling implemented in Epilogue
- Tensor-wise scaled: scaling implemented in Epilogue
## build
```
@@ -14,7 +15,6 @@ mkdir build && cd build
../script/cmake-ck-dev.sh ../ <arch>
# Compile the quant kernels
make tile_example_gemm_quant_basic -j
make tile_example_gemm_bquant_basic -j
```
This will result in an executable `build/bin/tile_example_gemm_quant_basic`
@@ -37,7 +37,7 @@ args:
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-quant_mode Which quant method to use (aquant, rowcol)
-quant_mode Which quant method to use (aquant, bquant, tensor, rowcol)
```
Users need to select the correct config mapping for each quant mode:

View File

@@ -66,19 +66,21 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
constexpr auto tail_number_v = tail_number_.value;
constexpr bool transpose_c = false;
// row-col and tensor quants use the regular pipeline, A/B quants use their own
using PipelineProblem = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant,
ck_tile::GemmRowColQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::BDataType,
typename TypeConfig::AccDataType,
typename TypeConfig::AccDataType,
GemmShape,
GemmTraits,
transpose_c,
ComputeDataType,
GemmConfig::Scheduler,
has_hot_loop_v,
tail_number_v>,
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant,
ck_tile::GemmRowColTensorQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::BDataType,
typename TypeConfig::AccDataType,
typename TypeConfig::AccDataType,
GemmShape,
GemmTraits,
transpose_c,
ComputeDataType,
GemmConfig::Scheduler,
has_hot_loop_v,
tail_number_v>,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::QDataType,
@@ -105,7 +107,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
tail_number_v>>>;
using GemmPipeline = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant,
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant,
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
@@ -241,10 +244,18 @@ int run_gemm_example(int argc, char* argv[])
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode! Use 'aquant', 'bquant' or 'rowcol'");
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
}
}
else if(data_type == "bf8")
@@ -276,10 +287,18 @@ int run_gemm_example(int argc, char* argv[])
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode! Use 'aquant', 'bquant' or 'rowcol'");
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
}
}
else if(data_type == "i4fp8")

View File

@@ -9,7 +9,7 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_group_quant.hpp"
#include "ck_tile/ops/gemm_quant.hpp"
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
@@ -241,7 +241,7 @@ auto create_args(int argc, char* argv[])
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
.insert("rotating_count", "1", "rotating count, defaults to 1")
.insert("quant_mode", "aquant", "Choose aquant (default), bquant or rowcol");
.insert("quant_mode", "aquant", "Choose aquant (default), bquant, tensor or rowcol");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -119,11 +119,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
}
std::cout << " Acc_Type = " << DataTypeTraits<typename TypeConfig::AccDataType>::name
<< " C_Type = " << DataTypeTraits<typename TypeConfig::CDataType>::name
<< " QuantMode = "
<< (QuantMode == ck_tile::QuantType::AQuantGrouped
? "AQuantGrouped"
: (QuantMode == ck_tile::QuantType::BQuantGrouped ? "BQuantGrouped"
: "RowColQuant"))
<< " QuantMode = " << quant_type_to_string(QuantMode)
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
@@ -183,10 +179,11 @@ int run_gemm_example_with_layouts(int argc,
AQK = 0; // No A quantization
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
}
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
AQK = 1; // Row quantization: tensor shape [M, 1]
BQK = N; // Column quantization: tensor shape [1, N]
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
}
else
{
@@ -227,6 +224,11 @@ int run_gemm_example_with_layouts(int argc,
stride_AQ = ck_tile::get_default_stride(M, 1, stride_AQ, is_row_major(aq_layout));
stride_BQ = ck_tile::get_default_stride(1, N, stride_BQ, is_row_major(bq_layout));
}
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
{
stride_AQ = 1; // Tensor quantization: tensor shape [1]
stride_BQ = 1; // Tensor quantization: tensor shape [1]
}
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
@@ -237,28 +239,30 @@ int run_gemm_example_with_layouts(int argc,
// Create AQ tensor with appropriate shape
std::unique_ptr<ck_tile::HostTensor<AQDataType>> aq_tensor_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
{
aq_tensor_ptr = std::make_unique<ck_tile::HostTensor<AQDataType>>(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
}
else if(QuantMode == ck_tile::QuantType::RowColQuant)
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
{
aq_tensor_ptr = std::make_unique<ck_tile::HostTensor<AQDataType>>(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
ck_tile::host_tensor_descriptor(1, 1, stride_AQ, is_row_major(aq_layout)));
}
// Create BQ tensor only for RowColQuant mode
// Create BQ tensor with appropriate shape
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout)));
}
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(1, N, stride_BQ, is_row_major(bq_layout)));
ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout)));
}
std::random_device rd;
@@ -282,7 +286,7 @@ int run_gemm_example_with_layouts(int argc,
*bq_tensor_ptr);
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
}
else
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
@@ -296,12 +300,15 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
}
else if(init_method == 1)
@@ -343,7 +350,8 @@ int run_gemm_example_with_layouts(int argc,
std::unique_ptr<ck_tile::DeviceMem> aq_dev_buf_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
aq_dev_buf_ptr =
std::make_unique<ck_tile::DeviceMem>(aq_tensor_ptr->get_element_space_size_in_bytes());
@@ -351,14 +359,16 @@ int run_gemm_example_with_layouts(int argc,
std::unique_ptr<ck_tile::DeviceMem> bq_dev_buf_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
bq_dev_buf_ptr =
std::make_unique<ck_tile::DeviceMem>(bq_tensor_ptr->get_element_space_size_in_bytes());
}
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
if constexpr(GemmConfig::PreshuffleQuant)
{
@@ -398,7 +408,8 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_result.SetZero();
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data());
}
@@ -412,15 +423,9 @@ int run_gemm_example_with_layouts(int argc,
CLayout,
QuantGroupSize,
QuantMode>(a_m_k_dev_buf,
(QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
? aq_dev_buf_ptr.get()
: nullptr,
aq_dev_buf_ptr.get(),
b_k_n_dev_buf,
(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
? bq_dev_buf_ptr.get()
: nullptr,
bq_dev_buf_ptr.get(),
c_m_n_dev_buf,
M,
N,
@@ -467,7 +472,7 @@ int run_gemm_example_with_layouts(int argc,
QuantGroupSize,
false>(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
}
else
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
ck_tile::reference_gemm_rowcol_quant<ADataType,
AQDataType,
@@ -477,6 +482,16 @@ int run_gemm_example_with_layouts(int argc,
CDataType>(
a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
}
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
{
ck_tile::reference_gemm_tensor_quant<ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType>(
a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
}
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());