[CK_TILE] Tensor-wise scaled quant gemm kernel (#2846)

* rename gemm_group_quant to gemm_quant

* Add TensorWise quant mode

* Cshuffle epilogue tests with tensor scaling

* Add tensor quant to example

* Don't use readfirstlane for reading scales - doesn't work for some reason

* Add to changelog

* revert include - from a merge problem?

* revert common.hpp include

* revert host.hpp include

* remove unused utility function

* rename quant pipeline problem

* refactor quant tests

* remove aquant utils

* use TEST_F

* fix all tests by changing gemm config

* Use typed tests

* fix copyright
This commit is contained in:
Sami Remes
2025-09-20 02:52:35 +03:00
committed by GitHub
parent b765fe78f3
commit 4363a82bd6
39 changed files with 1555 additions and 1056 deletions

View File

@@ -31,6 +31,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added benchmarking support for tile engine GEMM Multi D.
* Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands.
* Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM.
* Added tensor-wise quantization for CK_TILE GEMM
### Optimized

View File

@@ -13,7 +13,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_group_quant.hpp"
#include "ck_tile/ops/gemm_quant.hpp"
#include "ck_tile/host.hpp"
#include "quant_grouped_gemm.hpp"
@@ -65,15 +65,15 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
constexpr auto memory_operation = memory_operation_.value;
constexpr bool transpose_c = false;
using QuantGemmProblem = ck_tile::GemmRowColQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
transpose_c,
BDataType,
scheduler>;
using QuantGemmProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
transpose_c,
BDataType,
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<QuantGemmProblem>;

View File

@@ -5,6 +5,7 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming
- AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline
- BQuant kernel with blocks of B matrix sharing scales: custom GEMM pipeline
- Row and Column-wise scaled: scaling implemented in Epilogue
- Tensor-wise scaled: scaling implemented in Epilogue
## build
```
@@ -14,7 +15,6 @@ mkdir build && cd build
../script/cmake-ck-dev.sh ../ <arch>
# Compile the quant kernels
make tile_example_gemm_quant_basic -j
make tile_example_gemm_bquant_basic -j
```
This will result in an executable `build/bin/tile_example_gemm_quant_basic`
@@ -37,7 +37,7 @@ args:
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-quant_mode Which quant method to use (aquant, rowcol)
-quant_mode Which quant method to use (aquant, bquant, tensor, rowcol)
```
User need to select correct mapping of config for each quant mode:

View File

@@ -66,19 +66,21 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
constexpr auto tail_number_v = tail_number_.value;
constexpr bool transpose_c = false;
// row-col and tensor quants use the regular pipeline, A/B quants use their own
using PipelineProblem = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant,
ck_tile::GemmRowColQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::BDataType,
typename TypeConfig::AccDataType,
typename TypeConfig::AccDataType,
GemmShape,
GemmTraits,
transpose_c,
ComputeDataType,
GemmConfig::Scheduler,
has_hot_loop_v,
tail_number_v>,
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant,
ck_tile::GemmRowColTensorQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::BDataType,
typename TypeConfig::AccDataType,
typename TypeConfig::AccDataType,
GemmShape,
GemmTraits,
transpose_c,
ComputeDataType,
GemmConfig::Scheduler,
has_hot_loop_v,
tail_number_v>,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::QDataType,
@@ -105,7 +107,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
tail_number_v>>>;
using GemmPipeline = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant,
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant,
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
@@ -241,10 +244,18 @@ int run_gemm_example(int argc, char* argv[])
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode! Use 'aquant', 'bquant' or 'rowcol'");
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
}
}
else if(data_type == "bf8")
@@ -276,10 +287,18 @@ int run_gemm_example(int argc, char* argv[])
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode! Use 'aquant', 'bquant' or 'rowcol'");
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
}
}
else if(data_type == "i4fp8")

View File

@@ -9,7 +9,7 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_group_quant.hpp"
#include "ck_tile/ops/gemm_quant.hpp"
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
@@ -241,7 +241,7 @@ auto create_args(int argc, char* argv[])
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
.insert("rotating_count", "1", "rotating count, defaults to 1")
.insert("quant_mode", "aquant", "Choose aquant (default), bquant or rowcol");
.insert("quant_mode", "aquant", "Choose aquant (default), bquant, tensor or rowcol");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -119,11 +119,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
}
std::cout << " Acc_Type = " << DataTypeTraits<typename TypeConfig::AccDataType>::name
<< " C_Type = " << DataTypeTraits<typename TypeConfig::CDataType>::name
<< " QuantMode = "
<< (QuantMode == ck_tile::QuantType::AQuantGrouped
? "AQuantGrouped"
: (QuantMode == ck_tile::QuantType::BQuantGrouped ? "BQuantGrouped"
: "RowColQuant"))
<< " QuantMode = " << quant_type_to_string(QuantMode)
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
@@ -183,10 +179,11 @@ int run_gemm_example_with_layouts(int argc,
AQK = 0; // No A quantization
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
}
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
AQK = 1; // Row quantization: tensor shape [M, 1]
BQK = N; // Column quantization: tensor shape [1, N]
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
}
else
{
@@ -227,6 +224,11 @@ int run_gemm_example_with_layouts(int argc,
stride_AQ = ck_tile::get_default_stride(M, 1, stride_AQ, is_row_major(aq_layout));
stride_BQ = ck_tile::get_default_stride(1, N, stride_BQ, is_row_major(bq_layout));
}
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
{
stride_AQ = 1; // Tensor quantization: tensor shape [1]
stride_BQ = 1; // Tensor quantization: tensor shape [1]
}
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
@@ -237,28 +239,30 @@ int run_gemm_example_with_layouts(int argc,
// Create AQ tensor with appropriate shape
std::unique_ptr<ck_tile::HostTensor<AQDataType>> aq_tensor_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
{
aq_tensor_ptr = std::make_unique<ck_tile::HostTensor<AQDataType>>(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
}
else if(QuantMode == ck_tile::QuantType::RowColQuant)
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
{
aq_tensor_ptr = std::make_unique<ck_tile::HostTensor<AQDataType>>(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
ck_tile::host_tensor_descriptor(1, 1, stride_AQ, is_row_major(aq_layout)));
}
// Create BQ tensor only for RowColQuant mode
// Create BQ tensor with appropriate shape
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout)));
}
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(1, N, stride_BQ, is_row_major(bq_layout)));
ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout)));
}
std::random_device rd;
@@ -282,7 +286,7 @@ int run_gemm_example_with_layouts(int argc,
*bq_tensor_ptr);
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
}
else
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
@@ -296,12 +300,15 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
}
else if(init_method == 1)
@@ -343,7 +350,8 @@ int run_gemm_example_with_layouts(int argc,
std::unique_ptr<ck_tile::DeviceMem> aq_dev_buf_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
aq_dev_buf_ptr =
std::make_unique<ck_tile::DeviceMem>(aq_tensor_ptr->get_element_space_size_in_bytes());
@@ -351,14 +359,16 @@ int run_gemm_example_with_layouts(int argc,
std::unique_ptr<ck_tile::DeviceMem> bq_dev_buf_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
bq_dev_buf_ptr =
std::make_unique<ck_tile::DeviceMem>(bq_tensor_ptr->get_element_space_size_in_bytes());
}
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
if constexpr(GemmConfig::PreshuffleQuant)
{
@@ -398,7 +408,8 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_result.SetZero();
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data());
}
@@ -412,15 +423,9 @@ int run_gemm_example_with_layouts(int argc,
CLayout,
QuantGroupSize,
QuantMode>(a_m_k_dev_buf,
(QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
? aq_dev_buf_ptr.get()
: nullptr,
aq_dev_buf_ptr.get(),
b_k_n_dev_buf,
(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
? bq_dev_buf_ptr.get()
: nullptr,
bq_dev_buf_ptr.get(),
c_m_n_dev_buf,
M,
N,
@@ -467,7 +472,7 @@ int run_gemm_example_with_layouts(int argc,
QuantGroupSize,
false>(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
}
else
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
ck_tile::reference_gemm_rowcol_quant<ADataType,
AQDataType,
@@ -477,6 +482,16 @@ int run_gemm_example_with_layouts(int argc,
CDataType>(
a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
}
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
{
ck_tile::reference_gemm_tensor_quant<ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType>(
a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
}
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -158,4 +158,7 @@ CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<Windo
{
}
template <typename Tile>
concept IsLoadableTile = requires { load_tile(std::declval<Tile>()); };
} // namespace ck_tile

View File

@@ -180,10 +180,6 @@ CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, fp8_t>)
{
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
}
else
{
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
@@ -198,7 +194,57 @@ CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k
};
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
std::cout << std::endl;
}
template <typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor<ADataType>& a_m_k,
const HostTensor<AQDataType>& aq_1_1,
const HostTensor<BDataType>& b_k_n,
const HostTensor<BQDataType>& bq_1_1,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
static_assert(std::is_same_v<AccDataType, float>);
static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
// Init accumulator
AccDataType v_acc = 0;
// Get scale for A and scale for B
const AccDataType a_scale = ck_tile::type_convert<AccDataType>(aq_1_1(0, 0));
const AccDataType b_scale = ck_tile::type_convert<AccDataType>(bq_1_1(0, 0));
// Compute the dot product
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
AccDataType v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
v_acc += v_a * v_b;
}
v_acc = v_acc * a_scale * b_scale;
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
};
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}
template <typename ADataType,

View File

@@ -304,22 +304,41 @@ struct CShuffleEpilogue
CK_TILE_DEVICE void
scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
{
// Load tiles
const auto scale_m_tile = load_tile(scale_m_window);
const auto scale_n_tile = load_tile(scale_n_window);
// Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
tile_elementwise_inout(
element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
// Move scale windows
constexpr index_t num_access = SFC::get_num_of_access();
if constexpr(iAccess != num_access - 1)
// Check if scales are EmptyScale first (no scaling needed)
if constexpr(std::is_same_v<ScaleM, EmptyScale> && std::is_same_v<ScaleN, EmptyScale>)
{
constexpr auto step = SFC::get_forward_step(iAccess);
// No scaling needed - this is a no-op
}
// Check if scales are scalar AccDataType
else if constexpr(std::is_same_v<ScaleM, AccDataType> &&
std::is_same_v<ScaleN, AccDataType>)
{
// Handle scalar scales
const AccDataType scale_m = scale_m_window;
const AccDataType scale_n = scale_n_window;
tile_elementwise_inout([&](auto& element) { element = element * scale_m * scale_n; },
lds_tile);
}
// Otherwise, assume they are tile windows that can be loaded
else
{
// Load tiles
const auto scale_m_tile = load_tile(scale_m_window);
const auto scale_n_tile = load_tile(scale_n_window);
move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
// Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
tile_elementwise_inout(
element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
// Move scale windows
constexpr index_t num_access = SFC::get_num_of_access();
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
}
}
}
@@ -452,6 +471,8 @@ struct CShuffleEpilogue
// Optional scales (must share the same distribution to match per-thread indexing)
constexpr bool has_scales =
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
constexpr bool has_scalar_scales =
std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
// Tiles to hold row/col scales when present
using SMType = typename GetDataType<remove_cvref_t<ScaleM>>::type;
@@ -462,8 +483,11 @@ struct CShuffleEpilogue
// Build windows only if scales are provided
auto scale_m_window = [&]() {
if constexpr(has_scales)
if constexpr(has_scales && !has_scalar_scales)
{
static_assert(
IsLoadableTile<decltype(make_tile_window(scale_m, dram_tile_distribution))>,
"ScaleM must be a loadable tile");
return make_tile_window(scale_m, dram_tile_distribution);
}
else
@@ -472,8 +496,11 @@ struct CShuffleEpilogue
}
}();
auto scale_n_window = [&]() {
if constexpr(has_scales)
if constexpr(has_scales && !has_scalar_scales)
{
static_assert(
IsLoadableTile<decltype(make_tile_window(scale_n, dram_tile_distribution))>,
"ScaleN must be a loadable tile");
return make_tile_window(scale_n, dram_tile_distribution);
}
else
@@ -489,7 +516,7 @@ struct CShuffleEpilogue
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
// If scales provided, load them with identical distribution
if constexpr(has_scales)
if constexpr(has_scales && IsLoadableTile<ScaleM> && IsLoadableTile<ScaleN>)
{
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
@@ -504,7 +531,11 @@ struct CShuffleEpilogue
auto emit = [&](index_t out_idx, index_t src_row) {
AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row];
if constexpr(has_scales)
if constexpr(has_scalar_scales)
{
v = static_cast<AccDataType>(v * scale_m * scale_n);
}
else if constexpr(has_scales)
{
// same linear index mapping on the permuted distribution
const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
@@ -595,10 +626,19 @@ struct CShuffleEpilogue
number<NumDTensor>{});
constexpr bool has_scales =
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
!std::is_same_v<ScaleM, EmptyScale> && !std::is_same_v<ScaleN, EmptyScale>;
constexpr bool has_scalar_scales =
std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
auto scale_m_window = [&]() {
if constexpr(has_scales)
if constexpr(has_scalar_scales)
{
return scale_m;
}
else if constexpr(has_scales)
{
static_assert(
IsLoadableTile<decltype(make_tile_window(scale_m, dram_tile_distribution))>,
"ScaleM must be a loadable tile");
return make_tile_window(scale_m, lds_tile.get_tile_distribution());
}
else
@@ -607,8 +647,15 @@ struct CShuffleEpilogue
}
}();
auto scale_n_window = [&]() {
if constexpr(has_scales)
if constexpr(has_scalar_scales)
{
return scale_n;
}
else if constexpr(has_scales)
{
static_assert(
IsLoadableTile<decltype(make_tile_window(scale_n, dram_tile_distribution))>,
"ScaleN must be a loadable tile");
return make_tile_window(scale_n, lds_tile.get_tile_distribution());
}
else

View File

@@ -1,21 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/grouped_gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -0,0 +1,21 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -12,7 +12,7 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp"
namespace ck_tile {
@@ -330,7 +330,6 @@ struct QuantGemmKernel
}
}
// NOTE: no kernel currently uses BQuant like this:
if constexpr(kQuantType == QuantType::BQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
@@ -890,6 +889,7 @@ struct QuantGemmKernel
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param aq_ptr input AQ pointer
* @param bq_ptr input BQ pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
@@ -938,7 +938,8 @@ struct QuantGemmKernel
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
else if constexpr(kQuantType == QuantType::RowColQuant ||
kQuantType == QuantType::TensorQuant)
{
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0);
@@ -964,6 +965,18 @@ struct QuantGemmKernel
aq_block_window,
bq_block_window);
}
else if constexpr(kQuantType == QuantType::TensorQuant)
{
// TODO: why doesn't readfirstlane work here?
// const AccDataType aq_scale =
// __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*aq_ptr));
// const AccDataType bq_scale =
// __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*bq_ptr));
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
}
}
CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const

View File

@@ -9,7 +9,7 @@
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/host.hpp"
#include <hip/hip_runtime.h>

View File

@@ -9,7 +9,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {

View File

@@ -9,7 +9,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {

View File

@@ -168,17 +168,18 @@ template <typename ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
using GemmRowColQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
AccDataType_,
BDataType_,
AccDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
1, // no group size applicable
TransposeC_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
using GemmRowColTensorQuantPipelineProblem =
GemmQuantPipelineProblemBase<ADataType_,
AccDataType_,
BDataType_,
AccDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
1, // no group size applicable
TransposeC_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
} // namespace ck_tile

View File

@@ -12,9 +12,22 @@ enum struct QuantType : std::uint16_t
{
AQuantGrouped = 0,
BQuantGrouped = 1,
RowColQuant = 2
RowColQuant = 2,
TensorQuant = 3
};
std::string quant_type_to_string(QuantType quant_type)
{
switch(quant_type)
{
case QuantType::AQuantGrouped: return "AQuantGrouped";
case QuantType::BQuantGrouped: return "BQuantGrouped";
case QuantType::RowColQuant: return "RowColQuant";
case QuantType::TensorQuant: return "TensorQuant";
default: return "Unknown";
}
}
template <bool kPadM_,
bool kPadN_,
bool kPadK_,

View File

@@ -41,8 +41,8 @@ TEST_F(CShuffleEpilogueTest, BasicHalfTest)
NPerXdl,
KPerXdl>;
bool result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>();
EXPECT_TRUE(result) << "Basic CShuffleEpilogue test failed";
auto result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::None);
EXPECT_FLOAT_EQ(result[0], 2.0F) << "Basic CShuffleEpilogue test failed";
}
TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale)
@@ -73,8 +73,45 @@ TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale)
NPerXdl,
KPerXdl>;
bool result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(true);
EXPECT_TRUE(result) << "Scale CShuffleEpilogue test failed";
auto result =
run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::RowCol);
EXPECT_FLOAT_EQ(result[0], 2.0F) << "RowCol CShuffleEpilogue test failed: first element not 2";
EXPECT_FLOAT_EQ(result[1], 4.0F)
<< "RowCol CShuffleEpilogue test failed: second element not 2*2";
}
TEST_F(CShuffleEpilogueTest, BasicHalfTestWithTensorScale)
{
// Basic test configuration with half_t data types
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using ODataType = ck_tile::half_t;
constexpr index_t kMPerBlock = 256;
constexpr index_t kNPerBlock = 256;
constexpr index_t MWave = 2;
constexpr index_t NWave = 2;
constexpr index_t MPerXdl = 32;
constexpr index_t NPerXdl = 32;
constexpr index_t KPerXdl = 8;
using TestProblem = SimpleCShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
ODataType,
kMPerBlock,
kNPerBlock,
MWave,
NWave,
MPerXdl,
NPerXdl,
KPerXdl>;
auto result =
run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::Tensor);
EXPECT_FLOAT_EQ(result[0], 4.0F)
<< "TensorScale CShuffleEpilogue test failed: first element not 2*2=4";
}
int main(int argc, char** argv)

View File

@@ -19,8 +19,15 @@
namespace ck_tile {
enum class ScaleType
{
None,
RowCol,
Tensor
};
// Simple test kernel to invoke the CShuffleEpilogue
template <typename Problem, index_t M, index_t N, bool UseScale>
template <typename Problem, index_t M, index_t N, ScaleType Scale>
__global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __restrict__ output_data,
float* m_scale,
float* n_scale)
@@ -61,7 +68,7 @@ __global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __res
auto empty_ds = make_tuple();
// Call the epilogue
if constexpr(UseScale)
if constexpr(Scale == ScaleType::RowCol)
{
const auto m_scale_window = make_tile_window(
make_naive_tensor_view<address_space_enum::global>(
@@ -75,6 +82,10 @@ __global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __res
{0, 0});
Epilogue{}(output_tile_window, acc_tile, empty_ds, smem, m_scale_window, n_scale_window);
}
else if constexpr(Scale == ScaleType::Tensor)
{
Epilogue{}(output_tile_window, acc_tile, empty_ds, smem, *m_scale, *n_scale);
}
else
{
Epilogue{}(output_tile_window, acc_tile, empty_ds, smem);
@@ -113,7 +124,7 @@ using SimpleCShuffleEpilogueProblem =
memory_operation_enum::set>;
template <typename Problem, index_t M, index_t N>
bool run_cshuffle_epilogue_test(bool use_scale = false)
auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None)
{
using ODataType = typename Problem::ODataType;
@@ -142,7 +153,7 @@ bool run_cshuffle_epilogue_test(bool use_scale = false)
dim3 gridSize(1, 1, 1);
dim3 blockSize(kBlockSize, 1, 1);
if(use_scale)
if(scale == ScaleType::RowCol)
{
float* m_scale;
float* n_scale;
@@ -155,12 +166,25 @@ bool run_cshuffle_epilogue_test(bool use_scale = false)
hipMemcpy(m_scale, h_m_scale.data(), M * sizeof(float), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(
hipMemcpy(n_scale, h_n_scale.data(), N * sizeof(float), hipMemcpyHostToDevice));
test_cshuffle_epilogue_kernel<Problem, M, N, true>
test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::RowCol>
<<<gridSize, blockSize>>>(device_output, m_scale, n_scale);
}
else if(scale == ScaleType::Tensor)
{
float* m_scale;
float* n_scale;
std::vector<float> h_m_scale(1, 2.0F);
std::vector<float> h_n_scale(1, 1.0F);
HIP_CHECK_ERROR(hipMalloc(&m_scale, sizeof(float)));
HIP_CHECK_ERROR(hipMalloc(&n_scale, sizeof(float)));
HIP_CHECK_ERROR(hipMemcpy(m_scale, h_m_scale.data(), sizeof(float), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(n_scale, h_n_scale.data(), sizeof(float), hipMemcpyHostToDevice));
test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::Tensor>
<<<gridSize, blockSize>>>(device_output, m_scale, n_scale);
}
else
{
test_cshuffle_epilogue_kernel<Problem, M, N, false>
test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::None>
<<<gridSize, blockSize>>>(device_output, nullptr, nullptr);
}
@@ -172,20 +196,10 @@ bool run_cshuffle_epilogue_test(bool use_scale = false)
HIP_CHECK_ERROR(hipMemcpy(
host_output.data(), device_output, output_size * sizeof(ODataType), hipMemcpyDeviceToHost));
// Basic verification - just check that output has a 2, and 4 if using scaling
bool has_2 =
type_convert<float>(host_output[0]) > 1.9F && type_convert<float>(host_output[0]) < 2.1F;
bool scale_has_4 = true;
if(use_scale)
{
scale_has_4 = type_convert<float>(host_output[1]) > 3.9F &&
type_convert<float>(host_output[1]) < 4.1F;
}
// Cleanup
HIP_CHECK_ERROR(hipFree(device_output));
return has_2 && scale_has_4;
return host_output;
}
} // namespace ck_tile

View File

@@ -6,14 +6,9 @@ endif()
list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
set(TEST_GEMM_NAME test_tile_gemm_aquant_basic)
set(QUANT_TYPES fp8 bf8 i4fp8 i4bf8 i4f32fp8 i4f32bf8)
foreach(QUANT_TYPE ${QUANT_TYPES})
add_gtest_executable(${TEST_GEMM_NAME}_${QUANT_TYPE} test_gemm_aquant_basic_${QUANT_TYPE}.cpp)
target_compile_options(${TEST_GEMM_NAME}_${QUANT_TYPE} PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
endforeach()
# Typed Test Suite for GEMM Quantization
add_gtest_executable(test_tile_gemm_quant_typed test_gemm_quant_typed.cpp)
target_compile_options(test_tile_gemm_quant_typed PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
endif()

View File

@@ -1,6 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("bf8"); }

View File

@@ -1,6 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("fp8"); }

View File

@@ -1,6 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("i4bf8"); }

View File

@@ -1,6 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("i4f32bf8"); }

View File

@@ -1,6 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("i4f32fp8"); }

View File

@@ -1,6 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("i4fp8"); }

View File

@@ -1,243 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_group_quant.hpp"
#define CK_TILE_PIPELINE_PREFILL 1
#define CK_TILE_PIPELINE_DECODE 2
#define CK_TILE_PIPELINE_PRESHUFFLEQUANT 3
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
class ArgumentsNotSupportedException : public std::logic_error
{
public:
explicit ArgumentsNotSupportedException(const std::string& message) : logic_error(message) {}
};
struct GemmConfigBase
{
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool PreshuffleQuant = false;
static constexpr bool DoubleSmemBuffer = true;
};
template <typename PrecType>
struct GemmConfigDecode : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
};
template <typename PrecType>
struct GemmConfigPrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL;
};
template <typename PrecType>
struct GemmConfigPreshuffleQuant : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;
static constexpr bool PreshuffleQuant = true;
};
template <typename ADataType_,
typename BDataType_ = ADataType_,
typename CDataType_ = ADataType_,
typename QDataType_ = float>
struct GemmQuantTypeConfig
{
using ADataType = ADataType_;
using QDataType = QDataType_;
using BDataType = BDataType_;
using AccDataType = float;
using CDataType = CDataType_;
};
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<int32_t>
{
static constexpr const char* name = "int32";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("aq_layout", "R", "Aq tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Column by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_q", "0", "Tensor AQ stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "i4fp8", "data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("persistent", "0", "0:non-persistent, 1:persistent")
.insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// host API
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -0,0 +1,179 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <tuple>
#include <stdexcept>
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_quant.hpp"
// Forward declarations for quant type-specific implementations
template <ck_tile::QuantType QT>
struct QuantTypeTraits;
// Base class for common quant gemm functionality
template <typename Tuple, typename Derived>
class TestCkTileGemmQuantBase : public ::testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using QDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
static constexpr auto QuantType = std::tuple_element_t<7, Tuple>::value;
using GemmConfig = std::tuple_element_t<8, Tuple>;
static constexpr uint32_t QuantGroupSize = std::tuple_element_t<9, Tuple>::value;
using AccDataType = float; // accumulate always in float
// Get the quant-type specific data types from traits
using QuantTraits = QuantTypeTraits<QuantType>;
using ComputeDataType = typename QuantTraits::template ComputeDataType<ADataType, BDataType>;
static constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
static constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
static constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
static constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
static constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
static constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
static constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
static constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
public:
void SetUp() override { static_cast<Derived*>(this)->SetUpQuantTypeSpecific(); }
void TearDown() override { static_cast<Derived*>(this)->TearDownQuantTypeSpecific(); }
// Common test execution logic
void invoke_quant_gemm(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool kPreshuffle = false;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,
kPreshuffle,
ALayout,
BLayout,
CLayout,
QuantType>;
// Let the derived class create the appropriate pipeline and epilogue
static_cast<Derived*>(this)
->template run_quant_gemm_impl<CodegenGemmShape, TilePartitioner, CodegenGemmTraits>(
args, s);
}
void RunTest(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
{
// Generate test data and run the kernel
static_cast<Derived*>(this)->run_test_with_validation(M, N, K);
}
// Helper function to check layout
template <typename Layout>
static constexpr auto is_row_major(Layout)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(Layout{})>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Tolerance calculation function for validation
template <typename ADataType_, typename BDataType_, typename AccDataType_, typename CDataType_>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType_) < sizeof(BDataType_), ADataType_, BDataType_>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType_, AccDataType_>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType_, AccDataType_>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType_, CDataType_, CDataType_>(kbatch);
const auto atol_split_k =
ck_tile::get_absolute_threshold<CDataType_, CDataType_, CDataType_>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
};
// Define generic QuantTypeTraits template (will be specialized)
template <ck_tile::QuantType QT>
struct QuantTypeTraits
{
static_assert(QT == ck_tile::QuantType::AQuantGrouped ||
QT == ck_tile::QuantType::BQuantGrouped ||
QT == ck_tile::QuantType::RowColQuant ||
QT == ck_tile::QuantType::TensorQuant,
"Unsupported quantization type");
};
// Specialization for AQuantGrouped
template <>
struct QuantTypeTraits<ck_tile::QuantType::AQuantGrouped>
{
template <typename ADataType, typename BDataType>
using ComputeDataType = BDataType; // For AQuant, compute type is BDataType
static constexpr const char* name = "aquant";
};
// Specialization for BQuantGrouped
template <>
struct QuantTypeTraits<ck_tile::QuantType::BQuantGrouped>
{
template <typename ADataType, typename BDataType>
using ComputeDataType = ADataType; // For BQuant, compute type is ADataType
static constexpr const char* name = "bquant";
};
// Specialization for RowColQuant
template <>
struct QuantTypeTraits<ck_tile::QuantType::RowColQuant>
{
template <typename ADataType, typename BDataType>
using ComputeDataType = ADataType; // For RowColQuant, compute type is ADataType
static constexpr const char* name = "rowcol";
};
// Specialization for TensorQuant
template <>
struct QuantTypeTraits<ck_tile::QuantType::TensorQuant>
{
template <typename ADataType, typename BDataType>
using ComputeDataType = ADataType; // For TensorQuant, compute type is ADataType
static constexpr const char* name = "tensor";
};

View File

@@ -0,0 +1,919 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "test_gemm_quant_base.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
struct GemmConfigBase
{
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool PreshuffleQuant = false;
static constexpr bool DoubleSmemBuffer = false;
// Default GEMM tile sizes for tests
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
};
template <typename Tuple>
class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>
{
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>;
friend Base;
public:
using typename Base::AccDataType;
using typename Base::ADataType;
using typename Base::ALayout;
using typename Base::BDataType;
using typename Base::BLayout;
using typename Base::CDataType;
using typename Base::CLayout;
using typename Base::ComputeDataType;
using typename Base::QDataType;
static constexpr auto QuantType = Base::QuantType;
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
protected:
void SetUpQuantTypeSpecific() {}
void TearDownQuantTypeSpecific() {}
// AQuant-specific data generation
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_C = M;
// AQuant uses grouped quantization for A matrix
const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize);
const ck_tile::index_t stride_AQ =
ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(ALayout{}));
// Generate test data
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
ck_tile::HostTensor<QDataType> aq_m_aqk(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
// Initialize data with random values
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f}(a_m_k);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f}(a_m_k);
}
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f}(b_k_n);
ck_tile::FillUniformDistribution<QDataType>{-2.0f, 2.0f}(aq_m_aqk);
// Allocate device memory
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size() * sizeof(QDataType));
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
// Copy to device
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<ADataType> temp = a_m_k;
ck_tile::permute_vectors_i4x4_b(temp);
a_m_k_dev_buf.ToDevice(temp.data());
}
else
{
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
// Create args for kernel execution
ck_tile::QuantGemmHostArgs args{
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
aq_m_aqk_dev_buf.GetDeviceBuffer(), // aq_ptr (scales)
nullptr, // bq_ptr (not used for AQuant)
1, // k_batch
M,
N,
K, // M, N, K
AQK, // QK_A
0, // QK_B (not used for AQuant)
stride_A,
stride_B,
stride_C,
stride_AQ,
0 // strides
};
// Run the kernel
ck_tile::stream_config stream_config{};
this->invoke_quant_gemm(args, stream_config);
// Validation using reference implementation
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
// Run reference AQuant implementation
ck_tile::reference_gemm_quant<ADataType,
QDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
true>(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref);
// Get device result
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
// Calculate error tolerances
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol =
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, 1, max_accumulated_value);
// Validate results
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
EXPECT_TRUE(pass) << "AQuantGrouped validation failed with M=" << M << ", N=" << N
<< ", K=" << K;
if(!pass)
{
std::cout << "AQuantGrouped - Relative error threshold: "
<< rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
}
private:
// AQuant-specific pipeline implementation
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
const ck_tile::stream_config& s)
{
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr bool transpose_c = false;
using PipelineProblem =
ck_tile::GemmAQuantPipelineProblem<ADataType,
QDataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
transpose_c,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Base::M_Warp,
Base::N_Warp,
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
transpose_c,
ck_tile::memory_operation_enum::set>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
ck_tile::QuantType::AQuantGrouped>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Arguments not supported for AQuant kernel");
}
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
};
// BQuant-specific test fixture
template <typename Tuple>
class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmBQuant<Tuple>>
{
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmBQuant<Tuple>>;
friend Base;
public:
using typename Base::AccDataType;
using typename Base::ADataType;
using typename Base::ALayout;
using typename Base::BDataType;
using typename Base::BLayout;
using typename Base::CDataType;
using typename Base::CLayout;
using typename Base::ComputeDataType;
using typename Base::QDataType;
static constexpr auto QuantType = Base::QuantType;
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
protected:
void SetUpQuantTypeSpecific() {}
void TearDownQuantTypeSpecific() {}
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_C = M;
// BQuant uses grouped quantization for B matrix
const ck_tile::index_t BQK = ck_tile::integer_divide_ceil(K, QuantGroupSize);
const ck_tile::index_t stride_BQ = BQK;
// Generate test data
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
ck_tile::HostTensor<QDataType> bq_bqk_n(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, this->is_row_major(BLayout{})));
// Initialize data with random values
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{0.f, 1.f}(b_k_n);
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(bq_bqk_n);
// Allocate device memory
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
ck_tile::DeviceMem bq_bqk_n_dev_buf(bq_bqk_n.get_element_space_size() * sizeof(QDataType));
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
// Copy to device
a_m_k_dev_buf.ToDevice(a_m_k.data());
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> temp = b_k_n;
ck_tile::permute_vectors_i4x4_b(temp);
b_k_n_dev_buf.ToDevice(temp.data());
}
else
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
// Create args for kernel execution
ck_tile::QuantGemmHostArgs args{
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
nullptr, // aq_ptr (not used for BQuant)
bq_bqk_n_dev_buf.GetDeviceBuffer(), // bq_ptr (scales)
1, // k_batch
M,
N,
K, // M, N, K
0, // QK_A (not used for BQuant)
BQK, // QK_B
stride_A,
stride_B,
stride_C,
0,
stride_BQ // strides
};
// Run the kernel
ck_tile::stream_config stream_config{};
this->invoke_quant_gemm(args, stream_config);
// Validation using reference implementation
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
// Run reference BQuant implementation
ck_tile::reference_gemm_quant<ADataType,
QDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(a_m_k, bq_bqk_n, b_k_n, c_m_n_host_ref);
// Get device result
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
// Calculate error tolerances
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol =
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, 1, max_accumulated_value);
// Validate results
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
EXPECT_TRUE(pass) << "BQuantGrouped validation failed with M=" << M << ", N=" << N
<< ", K=" << K;
if(!pass)
{
std::cout << "BQuantGrouped - Relative error threshold: "
<< rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
}
private:
// BQuant-specific pipeline implementation
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
const ck_tile::stream_config& s)
{
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using PipelineProblem =
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
QDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Base::M_Warp,
Base::N_Warp,
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
false, // transpose_c
ck_tile::memory_operation_enum::set>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
ck_tile::QuantType::BQuantGrouped>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Arguments not supported for BQuant kernel");
}
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
};
// RowColQuant-specific test fixture
template <typename Tuple>
class TestCkTileGemmRowColQuant
: public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmRowColQuant<Tuple>>
{
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmRowColQuant<Tuple>>;
friend Base;
public:
using typename Base::AccDataType;
using typename Base::ADataType;
using typename Base::ALayout;
using typename Base::BDataType;
using typename Base::BLayout;
using typename Base::CDataType;
using typename Base::CLayout;
using typename Base::ComputeDataType;
using typename Base::QDataType;
static constexpr auto QuantType = Base::QuantType;
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
protected:
void SetUpQuantTypeSpecific() {}
void TearDownQuantTypeSpecific() {}
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_C = M;
// RowColQuant uses per-row and per-column scales
const ck_tile::index_t stride_row_scales = 1;
const ck_tile::index_t stride_col_scales = 1;
// Generate test data
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
ck_tile::HostTensor<QDataType> row_scales_m(ck_tile::host_tensor_descriptor(
M, 1, stride_row_scales, ck_tile::bool_constant<true>{}));
ck_tile::HostTensor<QDataType> col_scales_n(ck_tile::host_tensor_descriptor(
N, 1, stride_col_scales, ck_tile::bool_constant<true>{}));
// Initialize data with random values
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n);
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(row_scales_m);
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(col_scales_n);
// Allocate device memory
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
ck_tile::DeviceMem row_scales_dev_buf(row_scales_m.get_element_space_size() *
sizeof(QDataType));
ck_tile::DeviceMem col_scales_dev_buf(col_scales_n.get_element_space_size() *
sizeof(QDataType));
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
// Copy to device
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
row_scales_dev_buf.ToDevice(row_scales_m.data());
col_scales_dev_buf.ToDevice(col_scales_n.data());
// Create args for kernel execution
ck_tile::QuantGemmHostArgs args{
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
row_scales_dev_buf.GetDeviceBuffer(), // aq_ptr (row scales)
col_scales_dev_buf.GetDeviceBuffer(), // bq_ptr (col scales)
1, // k_batch
M,
N,
K, // M, N, K
1, // QK_A (row scales)
1, // QK_B (col scales)
stride_A,
stride_B,
stride_C,
stride_row_scales,
stride_col_scales // strides
};
// Run the kernel
ck_tile::stream_config stream_config{};
this->invoke_quant_gemm(args, stream_config);
// Validation using reference implementation
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
// Run reference RowColQuant implementation
ck_tile::reference_gemm_rowcol_quant<ADataType,
QDataType,
BDataType,
QDataType,
AccDataType,
CDataType>(
a_m_k, row_scales_m, b_k_n, col_scales_n, c_m_n_host_ref);
// Get device result
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
// Calculate error tolerances
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol =
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, 1, max_accumulated_value);
// Validate results
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
EXPECT_TRUE(pass) << "RowColQuant validation failed with M=" << M << ", N=" << N
<< ", K=" << K;
if(!pass)
{
std::cout << "RowColQuant - Relative error threshold: "
<< rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
}
private:
// RowColQuant-specific pipeline implementation
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
const ck_tile::stream_config& s)
{
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr bool transpose_c = false;
using PipelineProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<
ADataType,
BDataType,
AccDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
transpose_c,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Base::M_Warp,
Base::N_Warp,
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
transpose_c,
ck_tile::memory_operation_enum::set>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
ck_tile::QuantType::RowColQuant>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Arguments not supported for RowColQuant kernel");
}
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
};
// TensorQuant-specific test fixture
template <typename Tuple>
class TestCkTileGemmTensorQuant
: public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmTensorQuant<Tuple>>
{
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmTensorQuant<Tuple>>;
friend Base;
public:
using typename Base::AccDataType;
using typename Base::ADataType;
using typename Base::ALayout;
using typename Base::BDataType;
using typename Base::BLayout;
using typename Base::CDataType;
using typename Base::CLayout;
using typename Base::ComputeDataType;
using typename Base::QDataType;
static constexpr auto QuantType = Base::QuantType;
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
protected:
void SetUpQuantTypeSpecific() {}
void TearDownQuantTypeSpecific() {}
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_C = M;
// TensorQuant uses single scalar scale for each tensor
const ck_tile::index_t stride_scale_a = 1;
const ck_tile::index_t stride_scale_b = 1;
// Generate test data
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
ck_tile::HostTensor<QDataType> scale_a(
ck_tile::host_tensor_descriptor(1, 1, stride_scale_a, ck_tile::bool_constant<true>{}));
ck_tile::HostTensor<QDataType> scale_b(
ck_tile::host_tensor_descriptor(1, 1, stride_scale_b, ck_tile::bool_constant<true>{}));
// Initialize data with random values
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n);
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(scale_a);
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(scale_b);
// Allocate device memory
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
ck_tile::DeviceMem scale_a_dev_buf(scale_a.get_element_space_size() * sizeof(QDataType));
ck_tile::DeviceMem scale_b_dev_buf(scale_b.get_element_space_size() * sizeof(QDataType));
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
// Copy to device
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
scale_a_dev_buf.ToDevice(scale_a.data());
scale_b_dev_buf.ToDevice(scale_b.data());
// Create args for kernel execution
ck_tile::QuantGemmHostArgs args{
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
scale_a_dev_buf.GetDeviceBuffer(), // aq_ptr (scale A)
scale_b_dev_buf.GetDeviceBuffer(), // bq_ptr (scale B)
1, // k_batch
M,
N,
K, // M, N, K
1, // QK_A (tensor scale)
1, // QK_B (tensor scale)
stride_A,
stride_B,
stride_C,
stride_scale_a,
stride_scale_b // strides
};
// Run the kernel
ck_tile::stream_config stream_config{};
this->invoke_quant_gemm(args, stream_config);
// Validation using reference implementation
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
// Run reference TensorQuant implementation
ck_tile::reference_gemm_tensor_quant<ADataType,
QDataType,
BDataType,
QDataType,
AccDataType,
CDataType>(
a_m_k, scale_a, b_k_n, scale_b, c_m_n_host_ref);
// Get device result
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
// Calculate error tolerances
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol =
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, 1, max_accumulated_value);
// Validate results
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
EXPECT_TRUE(pass) << "TensorQuant validation failed with M=" << M << ", N=" << N
<< ", K=" << K;
if(!pass)
{
std::cout << "TensorQuant - Relative error threshold: "
<< rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
}
private:
// TensorQuant-specific pipeline implementation
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
const ck_tile::stream_config& s)
{
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr bool transpose_c = false;
using PipelineProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<
ADataType,
BDataType,
AccDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
transpose_c,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Base::M_Warp,
Base::N_Warp,
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
transpose_c,
ck_tile::memory_operation_enum::set>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
ck_tile::QuantType::TensorQuant>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Arguments not supported for TensorQuant kernel");
}
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
};

View File

@@ -0,0 +1,64 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
using GroupSize = std::integral_constant<unsigned int, 128>;
// Type combinations for each quantization type
// clang-format off
using AQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>
>;
// clang-format on
// clang-format off
using BQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>
>;
// clang-format on
// clang-format off
using RowColQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, RowColQuant, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, RowColQuant, GemmConfigBase, GroupSize>
>;
// clang-format on
// clang-format off
using TensorQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, TensorQuant, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, TensorQuant, GemmConfigBase, GroupSize>
>;
// clang-format on
// Test suites for each quantization type
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes);
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes);
TYPED_TEST_SUITE(TestCkTileGemmRowColQuant, RowColQuantTypes);
TYPED_TEST_SUITE(TestCkTileGemmTensorQuant, TensorQuantTypes);
#include "test_gemm_quant_ut_cases.inc"

View File

@@ -0,0 +1,28 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// AQuant tests
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}
// BQuant tests
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}
// RowColQuant tests
TYPED_TEST(TestCkTileGemmRowColQuant, RowColQuantTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}
// TensorQuant tests
TYPED_TEST(TestCkTileGemmTensorQuant, TensorQuantTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -1,616 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstring>
#include <iostream>
#include <ostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <random>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "test_gemm_aquant_utils.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ComputeDataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,
false, // preshuffle
ALayout,
BLayout,
CLayout,
ck_tile::QuantType::AQuantGrouped>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
constexpr bool transposed_warp_gemm = false;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using CodegenPipelineProblem =
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
transposed_warp_gemm,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
CodegenGemmPipeline,
GemmEpilogue,
ck_tile::QuantType::AQuantGrouped>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(args.k_batch != 1)
{
throw std::runtime_error("split-k is not supported yet!");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename AQLayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t AQK,
ck_tile::index_t stride_A,
ck_tile::index_t stride_AQ,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
{
ck_tile::QuantGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.QK_A = AQK;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;
args.stride_AQ = stride_AQ;
float ave_time = gemm_calc_aquant<GemmConfig,
ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
BDataType,
ALayout,
BLayout,
CLayout,
QuantGroupSize>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(AQDataType) * M * AQK +
sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name
<< " A_Type = " << DataTypeTraits<ADataType>::name
<< " AQ_Type = " << DataTypeTraits<AQDataType>::name
<< " B_Type = " << DataTypeTraits<BDataType>::name
<< " Acc_Type = " << DataTypeTraits<AccDataType>::name
<< " C_Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
template <typename GemmConfig,
typename TypeConfig,
uint32_t QuantGroupSize,
typename ALayout,
typename AQLayout,
typename BLayout,
typename CLayout>
bool run_gemm_test_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const AQLayout aq_layout = AQLayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return false;
using ADataType = typename TypeConfig::ADataType;
using AQDataType = typename TypeConfig::QDataType;
using BDataType = typename TypeConfig::BDataType;
using AccDataType = typename TypeConfig::AccDataType;
using CDataType = typename TypeConfig::CDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
if(K % QuantGroupSize != 0)
{
throw std::runtime_error("K must be aligned with QuantGroupSize");
}
ck_tile::index_t AQK = K / QuantGroupSize;
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_AQ = arg_parser.get_int("stride_q");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<AQDataType> aq_m_aqk(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
if(init_method == 0)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
}
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(aq_m_aqk);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
}
else if(init_method == 1)
{
std::cout << "Monotonic initialization is not supported." << std::endl;
return true;
}
else if(init_method == 2)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(aq_m_aqk);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
}
else
{
a_m_k.SetZero();
aq_m_aqk.SetZero();
b_k_n.SetZero();
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
}
else
{
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<GemmConfig,
ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
ALayout,
AQLayout,
BLayout,
CLayout,
QuantGroupSize>(a_m_k_dev_buf,
aq_m_aqk_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
AQK,
stride_A,
stride_AQ,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
true>(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
if(!pass)
{
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl;
return false;
}
return pass;
}
template <typename GemmConfig, typename TypeConfig, uint32_t QuantGroupSize>
bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_test_with_layouts<GemmConfig, TypeConfig, QuantGroupSize>(
argc, argv, Row{}, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
else
{
throw std::runtime_error("Unsupported data type for A.");
}
return true;
}
template <template <typename PreType> typename GemmConfig>
bool run_gemm_test(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return false;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4fp8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32fp8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
ck_tile::half_t,
float>{});
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
ck_tile::half_t,
float>{});
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int run_gemm_combinations(std::string const& data_type)
{
// Define possible values for each parameter
std::vector<std::vector<std::string>> mnk_values = {{
"1",
"2048",
"5120",
},
{
"2",
"2048",
"5120",
},
{
"16",
"2048",
"5120",
},
{
"17",
"2048",
"5120",
},
{
"2047",
"5120",
"1024",
},
{
"2048",
"5120",
"1024",
}};
std::vector<std::string> prec_values = {data_type};
// We'll store all our arguments as strings first
std::vector<std::string> arg_strings = {"test_tile_gemm_aquant_basic",
"", // m placeholder
"", // n placeholder
"", // k placeholder
"", // prec placeholder
"-init=0",
"-v=1",
"-warmup=0",
"-repeat=1"};
// Create an array of const char pointers for argv
constexpr size_t ARG_COUNT = 9;
constexpr size_t ARG_MAX_LEN = 64;
char args[ARG_COUNT][ARG_MAX_LEN];
char* argv[ARG_COUNT];
// Run all combinations
bool is_success = true;
for(const auto& mnk : mnk_values)
{
arg_strings[1] = "-m=" + mnk[0];
arg_strings[2] = "-n=" + mnk[1];
arg_strings[3] = "-k=" + mnk[2];
for(const auto& prec : prec_values)
{
arg_strings[4] = "-prec=" + prec;
// Set up the argv array with pointers to the string data
for(size_t i = 0; i < ARG_COUNT; i++)
{
strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN);
argv[i] = args[i];
}
std::cout << "Arguments received: ";
for(size_t i = 1; i < ARG_COUNT; ++i)
{
std::cout << argv[i] << " ";
}
std::cout << std::endl;
// Call the function with the current configuration
try
{
is_success = run_gemm_test<GemmConfigDecode>(ARG_COUNT, argv) && is_success;
}
catch(const ArgumentsNotSupportedException& e)
{
std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n';
// ArgumentsNotSupportedException is not an error. Do not change is_success
}
catch(const std::runtime_error& e)
{
std::cerr << "Caught runtime error: " << e.what() << '\n';
is_success = false;
}
}
}
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}