mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
[CK_TILE] Tensor-wise scaled quant gemm kernel (#2846)
* rename gemm_group_quant to gemm_quant
* Add TensorWise quant mode
* Cshuffle epilogue tests with tensor scaling
* Add tensor quant to example
* Don't use readfirstlane for reading scales - doesn't work for some reason
* Add to changelog
* revert include - from a merge problem?
* revert common.hpp include
* revert host.hpp include
* remove unused utility function
* rename quant pipeline problem
* refactor quant tests
* remove aquant utils
* use TEST_F
* fix all tests by changing gemm config
* Use typed tests
* fix copyright
[ROCm/composable_kernel commit: 4363a82bd6]
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "quant_grouped_gemm.hpp"
|
||||
|
||||
@@ -65,15 +65,15 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
using QuantGemmProblem = ck_tile::GemmRowColQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
transpose_c,
|
||||
BDataType,
|
||||
scheduler>;
|
||||
using QuantGemmProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
transpose_c,
|
||||
BDataType,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<QuantGemmProblem>;
|
||||
|
||||
@@ -5,6 +5,7 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming
|
||||
- AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline
|
||||
- BQuant kernel with blocks of B matrix sharing scales: custom GEMM pipeline
|
||||
- Row and Column-wise scaled: scaling implemented in Epilogue
|
||||
- Tensor-wise scaled: scaling implemented in Epilogue
|
||||
|
||||
## build
|
||||
```
|
||||
@@ -14,7 +15,6 @@ mkdir build && cd build
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# Compile the quant kernels
|
||||
make tile_example_gemm_quant_basic -j
|
||||
make tile_example_gemm_bquant_basic -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_quant_basic`
|
||||
|
||||
@@ -37,7 +37,7 @@ args:
|
||||
-warmup number of iterations before benchmark the kernel (default:10)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-quant_mode Which quant method to use (aquant, rowcol)
|
||||
-quant_mode Which quant method to use (aquant, bquant, tensor, rowcol)
|
||||
```
|
||||
|
||||
User need to select correct mapping of config for each quant mode:
|
||||
|
||||
@@ -66,19 +66,21 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
// row-col and tensor quants use the regular pipeline, A/B quants use their own
|
||||
using PipelineProblem = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant,
|
||||
ck_tile::GemmRowColQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::QDataType,
|
||||
@@ -105,7 +107,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
tail_number_v>>>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant,
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
@@ -241,10 +244,18 @@ int run_gemm_example(int argc, char* argv[])
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode! Use 'aquant', 'bquant' or 'rowcol'");
|
||||
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
@@ -276,10 +287,18 @@ int run_gemm_example(int argc, char* argv[])
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode! Use 'aquant', 'bquant' or 'rowcol'");
|
||||
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
|
||||
}
|
||||
}
|
||||
else if(data_type == "i4fp8")
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
@@ -241,7 +241,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
|
||||
.insert("rotating_count", "1", "rotating count, defaults to 1")
|
||||
.insert("quant_mode", "aquant", "Choose aquant (default), bquant or rowcol");
|
||||
.insert("quant_mode", "aquant", "Choose aquant (default), bquant, tensor or rowcol");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -119,11 +119,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
}
|
||||
std::cout << " Acc_Type = " << DataTypeTraits<typename TypeConfig::AccDataType>::name
|
||||
<< " C_Type = " << DataTypeTraits<typename TypeConfig::CDataType>::name
|
||||
<< " QuantMode = "
|
||||
<< (QuantMode == ck_tile::QuantType::AQuantGrouped
|
||||
? "AQuantGrouped"
|
||||
: (QuantMode == ck_tile::QuantType::BQuantGrouped ? "BQuantGrouped"
|
||||
: "RowColQuant"))
|
||||
<< " QuantMode = " << quant_type_to_string(QuantMode)
|
||||
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
|
||||
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
@@ -183,10 +179,11 @@ int run_gemm_example_with_layouts(int argc,
|
||||
AQK = 0; // No A quantization
|
||||
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
AQK = 1; // Row quantization: tensor shape [M, 1]
|
||||
BQK = N; // Column quantization: tensor shape [1, N]
|
||||
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
|
||||
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -227,6 +224,11 @@ int run_gemm_example_with_layouts(int argc,
|
||||
stride_AQ = ck_tile::get_default_stride(M, 1, stride_AQ, is_row_major(aq_layout));
|
||||
stride_BQ = ck_tile::get_default_stride(1, N, stride_BQ, is_row_major(bq_layout));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
stride_AQ = 1; // Tensor quantization: tensor shape [1]
|
||||
stride_BQ = 1; // Tensor quantization: tensor shape [1]
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
@@ -237,28 +239,30 @@ int run_gemm_example_with_layouts(int argc,
|
||||
|
||||
// Create AQ tensor with appropriate shape
|
||||
std::unique_ptr<ck_tile::HostTensor<AQDataType>> aq_tensor_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
aq_tensor_ptr = std::make_unique<ck_tile::HostTensor<AQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
|
||||
}
|
||||
else if(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
aq_tensor_ptr = std::make_unique<ck_tile::HostTensor<AQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_AQ, is_row_major(aq_layout)));
|
||||
}
|
||||
|
||||
// Create BQ tensor only for RowColQuant mode
|
||||
// Create BQ tensor with appropriate shape
|
||||
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(1, N, stride_BQ, is_row_major(bq_layout)));
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
|
||||
std::random_device rd;
|
||||
@@ -282,7 +286,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
*bq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
else
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
@@ -296,12 +300,15 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
else if(init_method == 1)
|
||||
@@ -343,7 +350,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> aq_dev_buf_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
aq_dev_buf_ptr =
|
||||
std::make_unique<ck_tile::DeviceMem>(aq_tensor_ptr->get_element_space_size_in_bytes());
|
||||
@@ -351,14 +359,16 @@ int run_gemm_example_with_layouts(int argc,
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> bq_dev_buf_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
bq_dev_buf_ptr =
|
||||
std::make_unique<ck_tile::DeviceMem>(bq_tensor_ptr->get_element_space_size_in_bytes());
|
||||
}
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
@@ -398,7 +408,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data());
|
||||
}
|
||||
@@ -412,15 +423,9 @@ int run_gemm_example_with_layouts(int argc,
|
||||
CLayout,
|
||||
QuantGroupSize,
|
||||
QuantMode>(a_m_k_dev_buf,
|
||||
(QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
? aq_dev_buf_ptr.get()
|
||||
: nullptr,
|
||||
aq_dev_buf_ptr.get(),
|
||||
b_k_n_dev_buf,
|
||||
(QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
? bq_dev_buf_ptr.get()
|
||||
: nullptr,
|
||||
bq_dev_buf_ptr.get(),
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
@@ -467,7 +472,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
QuantGroupSize,
|
||||
false>(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
|
||||
}
|
||||
else
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
ck_tile::reference_gemm_rowcol_quant<ADataType,
|
||||
AQDataType,
|
||||
@@ -477,6 +482,16 @@ int run_gemm_example_with_layouts(int argc,
|
||||
CDataType>(
|
||||
a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
ck_tile::reference_gemm_tensor_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType>(
|
||||
a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
|
||||
}
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
|
||||
Reference in New Issue
Block a user