mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Tensor-wise scaled quant gemm kernel (#2846)
* rename gemm_group_quant to gemm_quant * Add TensorWise quant mode * Cshuffle epilogue tests with tensor scaling * Add tensor quant to example * Don't use readfirstlane for reading scales - doesn't work for some reason * Add to changelog * revert include - from a merge problem? * revert common.hpp include * revert host.hpp include * remove unused utility function * rename quant pipeline problem * refactor quant tests * remove aquant utils * use TEST_F * fix all tests by changing gemm config * Use typed tests * fix copyright
This commit is contained in:
@@ -31,6 +31,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added benchmarking support for tile engine GEMM Multi D.
|
||||
* Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands.
|
||||
* Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM.
|
||||
* Added tensor-wise quantization for CK_TILE GEMM
|
||||
|
||||
### Optimized
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "quant_grouped_gemm.hpp"
|
||||
|
||||
@@ -65,15 +65,15 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
using QuantGemmProblem = ck_tile::GemmRowColQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
transpose_c,
|
||||
BDataType,
|
||||
scheduler>;
|
||||
using QuantGemmProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
transpose_c,
|
||||
BDataType,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<QuantGemmProblem>;
|
||||
|
||||
@@ -5,6 +5,7 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming
|
||||
- AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline
|
||||
- BQuant kernel with blocks of B matrix sharing scales: custom GEMM pipeline
|
||||
- Row and Column-wise scaled: scaling implemented in Epilogue
|
||||
- Tensor-wise scaled: scaling implemented in Epilogue
|
||||
|
||||
## build
|
||||
```
|
||||
@@ -14,7 +15,6 @@ mkdir build && cd build
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# Compile the quant kernels
|
||||
make tile_example_gemm_quant_basic -j
|
||||
make tile_example_gemm_bquant_basic -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_quant_basic`
|
||||
|
||||
@@ -37,7 +37,7 @@ args:
|
||||
-warmup number of iterations before benchmark the kernel (default:10)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-quant_mode Which quant method to use (aquant, rowcol)
|
||||
-quant_mode Which quant method to use (aquant, bquant, tensor, rowcol)
|
||||
```
|
||||
|
||||
User need to select correct mapping of config for each quant mode:
|
||||
|
||||
@@ -66,19 +66,21 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
// row-col and tensor quants use the regular pipeline, A/B quants use their own
|
||||
using PipelineProblem = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant,
|
||||
ck_tile::GemmRowColQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::QDataType,
|
||||
@@ -105,7 +107,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
tail_number_v>>>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant,
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
@@ -241,10 +244,18 @@ int run_gemm_example(int argc, char* argv[])
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode! Use 'aquant', 'bquant' or 'rowcol'");
|
||||
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
@@ -276,10 +287,18 @@ int run_gemm_example(int argc, char* argv[])
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode! Use 'aquant', 'bquant' or 'rowcol'");
|
||||
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
|
||||
}
|
||||
}
|
||||
else if(data_type == "i4fp8")
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
@@ -241,7 +241,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
|
||||
.insert("rotating_count", "1", "rotating count, defaults to 1")
|
||||
.insert("quant_mode", "aquant", "Choose aquant (default), bquant or rowcol");
|
||||
.insert("quant_mode", "aquant", "Choose aquant (default), bquant, tensor or rowcol");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -119,11 +119,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
}
|
||||
std::cout << " Acc_Type = " << DataTypeTraits<typename TypeConfig::AccDataType>::name
|
||||
<< " C_Type = " << DataTypeTraits<typename TypeConfig::CDataType>::name
|
||||
<< " QuantMode = "
|
||||
<< (QuantMode == ck_tile::QuantType::AQuantGrouped
|
||||
? "AQuantGrouped"
|
||||
: (QuantMode == ck_tile::QuantType::BQuantGrouped ? "BQuantGrouped"
|
||||
: "RowColQuant"))
|
||||
<< " QuantMode = " << quant_type_to_string(QuantMode)
|
||||
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
|
||||
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
@@ -183,10 +179,11 @@ int run_gemm_example_with_layouts(int argc,
|
||||
AQK = 0; // No A quantization
|
||||
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
AQK = 1; // Row quantization: tensor shape [M, 1]
|
||||
BQK = N; // Column quantization: tensor shape [1, N]
|
||||
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
|
||||
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -227,6 +224,11 @@ int run_gemm_example_with_layouts(int argc,
|
||||
stride_AQ = ck_tile::get_default_stride(M, 1, stride_AQ, is_row_major(aq_layout));
|
||||
stride_BQ = ck_tile::get_default_stride(1, N, stride_BQ, is_row_major(bq_layout));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
stride_AQ = 1; // Tensor quantization: tensor shape [1]
|
||||
stride_BQ = 1; // Tensor quantization: tensor shape [1]
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
@@ -237,28 +239,30 @@ int run_gemm_example_with_layouts(int argc,
|
||||
|
||||
// Create AQ tensor with appropriate shape
|
||||
std::unique_ptr<ck_tile::HostTensor<AQDataType>> aq_tensor_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
aq_tensor_ptr = std::make_unique<ck_tile::HostTensor<AQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
|
||||
}
|
||||
else if(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
aq_tensor_ptr = std::make_unique<ck_tile::HostTensor<AQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_AQ, is_row_major(aq_layout)));
|
||||
}
|
||||
|
||||
// Create BQ tensor only for RowColQuant mode
|
||||
// Create BQ tensor with appropriate shape
|
||||
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(1, N, stride_BQ, is_row_major(bq_layout)));
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
|
||||
std::random_device rd;
|
||||
@@ -282,7 +286,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
*bq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
else
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
@@ -296,12 +300,15 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
else if(init_method == 1)
|
||||
@@ -343,7 +350,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> aq_dev_buf_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
aq_dev_buf_ptr =
|
||||
std::make_unique<ck_tile::DeviceMem>(aq_tensor_ptr->get_element_space_size_in_bytes());
|
||||
@@ -351,14 +359,16 @@ int run_gemm_example_with_layouts(int argc,
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> bq_dev_buf_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
bq_dev_buf_ptr =
|
||||
std::make_unique<ck_tile::DeviceMem>(bq_tensor_ptr->get_element_space_size_in_bytes());
|
||||
}
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
@@ -398,7 +408,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data());
|
||||
}
|
||||
@@ -412,15 +423,9 @@ int run_gemm_example_with_layouts(int argc,
|
||||
CLayout,
|
||||
QuantGroupSize,
|
||||
QuantMode>(a_m_k_dev_buf,
|
||||
(QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
? aq_dev_buf_ptr.get()
|
||||
: nullptr,
|
||||
aq_dev_buf_ptr.get(),
|
||||
b_k_n_dev_buf,
|
||||
(QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
? bq_dev_buf_ptr.get()
|
||||
: nullptr,
|
||||
bq_dev_buf_ptr.get(),
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
@@ -467,7 +472,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
QuantGroupSize,
|
||||
false>(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
|
||||
}
|
||||
else
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
ck_tile::reference_gemm_rowcol_quant<ADataType,
|
||||
AQDataType,
|
||||
@@ -477,6 +482,16 @@ int run_gemm_example_with_layouts(int argc,
|
||||
CDataType>(
|
||||
a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
ck_tile::reference_gemm_tensor_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType>(
|
||||
a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
|
||||
}
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -158,4 +158,7 @@ CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<Windo
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Tile>
|
||||
concept IsLoadableTile = requires { load_tile(std::declval<Tile>()); };
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -180,10 +180,6 @@ CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
@@ -198,7 +194,57 @@ CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename ACCElementOp = ck_tile::identity>
|
||||
CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<AQDataType>& aq_1_1,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
const HostTensor<BQDataType>& bq_1_1,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
|
||||
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
|
||||
static_assert(std::is_same_v<AccDataType, float>);
|
||||
static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
|
||||
static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto f_mn = [&](auto m, auto n) {
|
||||
// Init accumulator
|
||||
AccDataType v_acc = 0;
|
||||
// Get scale for A and scale for B
|
||||
const AccDataType a_scale = ck_tile::type_convert<AccDataType>(aq_1_1(0, 0));
|
||||
const AccDataType b_scale = ck_tile::type_convert<AccDataType>(bq_1_1(0, 0));
|
||||
|
||||
// Compute the dot product
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
AccDataType v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
AccDataType v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
v_acc = v_acc * a_scale * b_scale;
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
|
||||
@@ -304,22 +304,41 @@ struct CShuffleEpilogue
|
||||
CK_TILE_DEVICE void
|
||||
scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
|
||||
{
|
||||
// Load tiles
|
||||
const auto scale_m_tile = load_tile(scale_m_window);
|
||||
const auto scale_n_tile = load_tile(scale_n_window);
|
||||
|
||||
// Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
|
||||
tile_elementwise_inout(
|
||||
element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
|
||||
|
||||
// Move scale windows
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
// Check if scales are EmptyScale first (no scaling needed)
|
||||
if constexpr(std::is_same_v<ScaleM, EmptyScale> && std::is_same_v<ScaleN, EmptyScale>)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
// No scaling needed - this is a no-op
|
||||
}
|
||||
// Check if scales are scalar AccDataType
|
||||
else if constexpr(std::is_same_v<ScaleM, AccDataType> &&
|
||||
std::is_same_v<ScaleN, AccDataType>)
|
||||
{
|
||||
// Handle scalar scales
|
||||
const AccDataType scale_m = scale_m_window;
|
||||
const AccDataType scale_n = scale_n_window;
|
||||
tile_elementwise_inout([&](auto& element) { element = element * scale_m * scale_n; },
|
||||
lds_tile);
|
||||
}
|
||||
// Otherwise, assume they are tile windows that can be loaded
|
||||
else
|
||||
{
|
||||
// Load tiles
|
||||
const auto scale_m_tile = load_tile(scale_m_window);
|
||||
const auto scale_n_tile = load_tile(scale_n_window);
|
||||
|
||||
move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
// Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
|
||||
tile_elementwise_inout(
|
||||
element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
|
||||
|
||||
// Move scale windows
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -452,6 +471,8 @@ struct CShuffleEpilogue
|
||||
// Optional scales (must share the same distribution to match per-thread indexing)
|
||||
constexpr bool has_scales =
|
||||
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
|
||||
constexpr bool has_scalar_scales =
|
||||
std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
|
||||
|
||||
// Tiles to hold row/col scales when present
|
||||
using SMType = typename GetDataType<remove_cvref_t<ScaleM>>::type;
|
||||
@@ -462,8 +483,11 @@ struct CShuffleEpilogue
|
||||
|
||||
// Build windows only if scales are provided
|
||||
auto scale_m_window = [&]() {
|
||||
if constexpr(has_scales)
|
||||
if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
static_assert(
|
||||
IsLoadableTile<decltype(make_tile_window(scale_m, dram_tile_distribution))>,
|
||||
"ScaleM must be a loadable tile");
|
||||
return make_tile_window(scale_m, dram_tile_distribution);
|
||||
}
|
||||
else
|
||||
@@ -472,8 +496,11 @@ struct CShuffleEpilogue
|
||||
}
|
||||
}();
|
||||
auto scale_n_window = [&]() {
|
||||
if constexpr(has_scales)
|
||||
if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
static_assert(
|
||||
IsLoadableTile<decltype(make_tile_window(scale_n, dram_tile_distribution))>,
|
||||
"ScaleN must be a loadable tile");
|
||||
return make_tile_window(scale_n, dram_tile_distribution);
|
||||
}
|
||||
else
|
||||
@@ -489,7 +516,7 @@ struct CShuffleEpilogue
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
// If scales provided, load them with identical distribution
|
||||
if constexpr(has_scales)
|
||||
if constexpr(has_scales && IsLoadableTile<ScaleM> && IsLoadableTile<ScaleN>)
|
||||
{
|
||||
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
|
||||
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
|
||||
@@ -504,7 +531,11 @@ struct CShuffleEpilogue
|
||||
auto emit = [&](index_t out_idx, index_t src_row) {
|
||||
AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row];
|
||||
|
||||
if constexpr(has_scales)
|
||||
if constexpr(has_scalar_scales)
|
||||
{
|
||||
v = static_cast<AccDataType>(v * scale_m * scale_n);
|
||||
}
|
||||
else if constexpr(has_scales)
|
||||
{
|
||||
// same linear index mapping on the permuted distribution
|
||||
const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
|
||||
@@ -595,10 +626,19 @@ struct CShuffleEpilogue
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr bool has_scales =
|
||||
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
|
||||
!std::is_same_v<ScaleM, EmptyScale> && !std::is_same_v<ScaleN, EmptyScale>;
|
||||
constexpr bool has_scalar_scales =
|
||||
std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
|
||||
auto scale_m_window = [&]() {
|
||||
if constexpr(has_scales)
|
||||
if constexpr(has_scalar_scales)
|
||||
{
|
||||
return scale_m;
|
||||
}
|
||||
else if constexpr(has_scales)
|
||||
{
|
||||
static_assert(
|
||||
IsLoadableTile<decltype(make_tile_window(scale_m, dram_tile_distribution))>,
|
||||
"ScaleM must be a loadable tile");
|
||||
return make_tile_window(scale_m, lds_tile.get_tile_distribution());
|
||||
}
|
||||
else
|
||||
@@ -607,8 +647,15 @@ struct CShuffleEpilogue
|
||||
}
|
||||
}();
|
||||
auto scale_n_window = [&]() {
|
||||
if constexpr(has_scales)
|
||||
if constexpr(has_scalar_scales)
|
||||
{
|
||||
return scale_n;
|
||||
}
|
||||
else if constexpr(has_scales)
|
||||
{
|
||||
static_assert(
|
||||
IsLoadableTile<decltype(make_tile_window(scale_n, dram_tile_distribution))>,
|
||||
"ScaleN must be a loadable tile");
|
||||
return make_tile_window(scale_n, lds_tile.get_tile_distribution());
|
||||
}
|
||||
else
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/kernel/grouped_gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
21
include/ck_tile/ops/gemm_quant.hpp
Normal file
21
include/ck_tile/ops/gemm_quant.hpp
Normal file
@@ -0,0 +1,21 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -330,7 +330,6 @@ struct QuantGemmKernel
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: no kernel currently uses BQuant like this:
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
@@ -890,6 +889,7 @@ struct QuantGemmKernel
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param aq_ptr input AQ pointer
|
||||
* @param bq_ptr input BQ pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
@@ -938,7 +938,8 @@ struct QuantGemmKernel
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant ||
|
||||
kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
@@ -964,6 +965,18 @@ struct QuantGemmKernel
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
// TODO: why doesn't readfirstlane work here?
|
||||
// const AccDataType aq_scale =
|
||||
// __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*aq_ptr));
|
||||
// const AccDataType bq_scale =
|
||||
// __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*bq_ptr));
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -168,17 +168,18 @@ template <typename ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full>
|
||||
using GemmRowColQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
|
||||
AccDataType_,
|
||||
BDataType_,
|
||||
AccDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
1, // no group size applicable
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
using GemmRowColTensorQuantPipelineProblem =
|
||||
GemmQuantPipelineProblemBase<ADataType_,
|
||||
AccDataType_,
|
||||
BDataType_,
|
||||
AccDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
1, // no group size applicable
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
} // namespace ck_tile
|
||||
@@ -12,9 +12,22 @@ enum struct QuantType : std::uint16_t
|
||||
{
|
||||
AQuantGrouped = 0,
|
||||
BQuantGrouped = 1,
|
||||
RowColQuant = 2
|
||||
RowColQuant = 2,
|
||||
TensorQuant = 3
|
||||
};
|
||||
|
||||
std::string quant_type_to_string(QuantType quant_type)
|
||||
{
|
||||
switch(quant_type)
|
||||
{
|
||||
case QuantType::AQuantGrouped: return "AQuantGrouped";
|
||||
case QuantType::BQuantGrouped: return "BQuantGrouped";
|
||||
case QuantType::RowColQuant: return "RowColQuant";
|
||||
case QuantType::TensorQuant: return "TensorQuant";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
@@ -41,8 +41,8 @@ TEST_F(CShuffleEpilogueTest, BasicHalfTest)
|
||||
NPerXdl,
|
||||
KPerXdl>;
|
||||
|
||||
bool result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>();
|
||||
EXPECT_TRUE(result) << "Basic CShuffleEpilogue test failed";
|
||||
auto result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::None);
|
||||
EXPECT_FLOAT_EQ(result[0], 2.0F) << "Basic CShuffleEpilogue test failed";
|
||||
}
|
||||
|
||||
TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale)
|
||||
@@ -73,8 +73,45 @@ TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale)
|
||||
NPerXdl,
|
||||
KPerXdl>;
|
||||
|
||||
bool result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(true);
|
||||
EXPECT_TRUE(result) << "Scale CShuffleEpilogue test failed";
|
||||
auto result =
|
||||
run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::RowCol);
|
||||
EXPECT_FLOAT_EQ(result[0], 2.0F) << "RowCol CShuffleEpilogue test failed: first element not 2";
|
||||
EXPECT_FLOAT_EQ(result[1], 4.0F)
|
||||
<< "RowCol CShuffleEpilogue test failed: second element not 2*2";
|
||||
}
|
||||
|
||||
TEST_F(CShuffleEpilogueTest, BasicHalfTestWithTensorScale)
|
||||
{
|
||||
// Basic test configuration with half_t data types
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using ODataType = ck_tile::half_t;
|
||||
|
||||
constexpr index_t kMPerBlock = 256;
|
||||
constexpr index_t kNPerBlock = 256;
|
||||
constexpr index_t MWave = 2;
|
||||
constexpr index_t NWave = 2;
|
||||
constexpr index_t MPerXdl = 32;
|
||||
constexpr index_t NPerXdl = 32;
|
||||
constexpr index_t KPerXdl = 8;
|
||||
|
||||
using TestProblem = SimpleCShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
ODataType,
|
||||
kMPerBlock,
|
||||
kNPerBlock,
|
||||
MWave,
|
||||
NWave,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
KPerXdl>;
|
||||
|
||||
auto result =
|
||||
run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::Tensor);
|
||||
EXPECT_FLOAT_EQ(result[0], 4.0F)
|
||||
<< "TensorScale CShuffleEpilogue test failed: first element not 2*2=4";
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
|
||||
@@ -19,8 +19,15 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum class ScaleType
|
||||
{
|
||||
None,
|
||||
RowCol,
|
||||
Tensor
|
||||
};
|
||||
|
||||
// Simple test kernel to invoke the CShuffleEpilogue
|
||||
template <typename Problem, index_t M, index_t N, bool UseScale>
|
||||
template <typename Problem, index_t M, index_t N, ScaleType Scale>
|
||||
__global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __restrict__ output_data,
|
||||
float* m_scale,
|
||||
float* n_scale)
|
||||
@@ -61,7 +68,7 @@ __global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __res
|
||||
auto empty_ds = make_tuple();
|
||||
|
||||
// Call the epilogue
|
||||
if constexpr(UseScale)
|
||||
if constexpr(Scale == ScaleType::RowCol)
|
||||
{
|
||||
const auto m_scale_window = make_tile_window(
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -75,6 +82,10 @@ __global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __res
|
||||
{0, 0});
|
||||
Epilogue{}(output_tile_window, acc_tile, empty_ds, smem, m_scale_window, n_scale_window);
|
||||
}
|
||||
else if constexpr(Scale == ScaleType::Tensor)
|
||||
{
|
||||
Epilogue{}(output_tile_window, acc_tile, empty_ds, smem, *m_scale, *n_scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(output_tile_window, acc_tile, empty_ds, smem);
|
||||
@@ -113,7 +124,7 @@ using SimpleCShuffleEpilogueProblem =
|
||||
memory_operation_enum::set>;
|
||||
|
||||
template <typename Problem, index_t M, index_t N>
|
||||
bool run_cshuffle_epilogue_test(bool use_scale = false)
|
||||
auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None)
|
||||
{
|
||||
using ODataType = typename Problem::ODataType;
|
||||
|
||||
@@ -142,7 +153,7 @@ bool run_cshuffle_epilogue_test(bool use_scale = false)
|
||||
dim3 gridSize(1, 1, 1);
|
||||
dim3 blockSize(kBlockSize, 1, 1);
|
||||
|
||||
if(use_scale)
|
||||
if(scale == ScaleType::RowCol)
|
||||
{
|
||||
float* m_scale;
|
||||
float* n_scale;
|
||||
@@ -155,12 +166,25 @@ bool run_cshuffle_epilogue_test(bool use_scale = false)
|
||||
hipMemcpy(m_scale, h_m_scale.data(), M * sizeof(float), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(n_scale, h_n_scale.data(), N * sizeof(float), hipMemcpyHostToDevice));
|
||||
test_cshuffle_epilogue_kernel<Problem, M, N, true>
|
||||
test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::RowCol>
|
||||
<<<gridSize, blockSize>>>(device_output, m_scale, n_scale);
|
||||
}
|
||||
else if(scale == ScaleType::Tensor)
|
||||
{
|
||||
float* m_scale;
|
||||
float* n_scale;
|
||||
std::vector<float> h_m_scale(1, 2.0F);
|
||||
std::vector<float> h_n_scale(1, 1.0F);
|
||||
HIP_CHECK_ERROR(hipMalloc(&m_scale, sizeof(float)));
|
||||
HIP_CHECK_ERROR(hipMalloc(&n_scale, sizeof(float)));
|
||||
HIP_CHECK_ERROR(hipMemcpy(m_scale, h_m_scale.data(), sizeof(float), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(n_scale, h_n_scale.data(), sizeof(float), hipMemcpyHostToDevice));
|
||||
test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::Tensor>
|
||||
<<<gridSize, blockSize>>>(device_output, m_scale, n_scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
test_cshuffle_epilogue_kernel<Problem, M, N, false>
|
||||
test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::None>
|
||||
<<<gridSize, blockSize>>>(device_output, nullptr, nullptr);
|
||||
}
|
||||
|
||||
@@ -172,20 +196,10 @@ bool run_cshuffle_epilogue_test(bool use_scale = false)
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
host_output.data(), device_output, output_size * sizeof(ODataType), hipMemcpyDeviceToHost));
|
||||
|
||||
// Basic verification - just check that output has a 2, and 4 if using scaling
|
||||
bool has_2 =
|
||||
type_convert<float>(host_output[0]) > 1.9F && type_convert<float>(host_output[0]) < 2.1F;
|
||||
bool scale_has_4 = true;
|
||||
if(use_scale)
|
||||
{
|
||||
scale_has_4 = type_convert<float>(host_output[1]) > 3.9F &&
|
||||
type_convert<float>(host_output[1]) < 4.1F;
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
HIP_CHECK_ERROR(hipFree(device_output));
|
||||
|
||||
return has_2 && scale_has_4;
|
||||
return host_output;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,14 +6,9 @@ endif()
|
||||
list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
set(TEST_GEMM_NAME test_tile_gemm_aquant_basic)
|
||||
set(QUANT_TYPES fp8 bf8 i4fp8 i4bf8 i4f32fp8 i4f32bf8)
|
||||
|
||||
foreach(QUANT_TYPE ${QUANT_TYPES})
|
||||
add_gtest_executable(${TEST_GEMM_NAME}_${QUANT_TYPE} test_gemm_aquant_basic_${QUANT_TYPE}.cpp)
|
||||
target_compile_options(${TEST_GEMM_NAME}_${QUANT_TYPE} PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
endforeach()
|
||||
|
||||
# Typed Test Suite for GEMM Quantization
|
||||
add_gtest_executable(test_tile_gemm_quant_typed test_gemm_quant_typed.cpp)
|
||||
target_compile_options(test_tile_gemm_quant_typed PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
|
||||
endif()
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_run_gemm_aquant_example.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("bf8"); }
|
||||
@@ -1,6 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_run_gemm_aquant_example.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("fp8"); }
|
||||
@@ -1,6 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_run_gemm_aquant_example.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("i4bf8"); }
|
||||
@@ -1,6 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_run_gemm_aquant_example.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("i4f32bf8"); }
|
||||
@@ -1,6 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_run_gemm_aquant_example.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("i4f32fp8"); }
|
||||
@@ -1,6 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_run_gemm_aquant_example.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("i4fp8"); }
|
||||
@@ -1,243 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_PREFILL 1
|
||||
#define CK_TILE_PIPELINE_DECODE 2
|
||||
#define CK_TILE_PIPELINE_PRESHUFFLEQUANT 3
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
class ArgumentsNotSupportedException : public std::logic_error
|
||||
{
|
||||
public:
|
||||
explicit ArgumentsNotSupportedException(const std::string& message) : logic_error(message) {}
|
||||
};
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigDecode : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_ = ADataType_,
|
||||
typename CDataType_ = ADataType_,
|
||||
typename QDataType_ = float>
|
||||
struct GemmQuantTypeConfig
|
||||
{
|
||||
using ADataType = ADataType_;
|
||||
using QDataType = QDataType_;
|
||||
using BDataType = BDataType_;
|
||||
using AccDataType = float;
|
||||
using CDataType = CDataType_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("aq_layout", "R", "Aq tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Column by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_q", "0", "Tensor AQ stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "i4fp8", "data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("persistent", "0", "0:non-persistent, 1:persistent")
|
||||
.insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// host API
|
||||
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
179
test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp
Normal file
179
test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp
Normal file
@@ -0,0 +1,179 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <stdexcept>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/check_err.hpp"
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
|
||||
// Forward declarations for quant type-specific implementations
|
||||
template <ck_tile::QuantType QT>
|
||||
struct QuantTypeTraits;
|
||||
|
||||
// Base class for common quant gemm functionality
|
||||
template <typename Tuple, typename Derived>
|
||||
class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using QDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
static constexpr auto QuantType = std::tuple_element_t<7, Tuple>::value;
|
||||
using GemmConfig = std::tuple_element_t<8, Tuple>;
|
||||
static constexpr uint32_t QuantGroupSize = std::tuple_element_t<9, Tuple>::value;
|
||||
using AccDataType = float; // accumulate always in float
|
||||
|
||||
// Get the quant-type specific data types from traits
|
||||
using QuantTraits = QuantTypeTraits<QuantType>;
|
||||
using ComputeDataType = typename QuantTraits::template ComputeDataType<ADataType, BDataType>;
|
||||
|
||||
static constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
static constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
static constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||
static constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||
static constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
public:
|
||||
void SetUp() override { static_cast<Derived*>(this)->SetUpQuantTypeSpecific(); }
|
||||
|
||||
void TearDown() override { static_cast<Derived*>(this)->TearDownQuantTypeSpecific(); }
|
||||
|
||||
// Common test execution logic
|
||||
void invoke_quant_gemm(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
constexpr bool kPreshuffle = false;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
kPreshuffle,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantType>;
|
||||
|
||||
// Let the derived class create the appropriate pipeline and epilogue
|
||||
static_cast<Derived*>(this)
|
||||
->template run_quant_gemm_impl<CodegenGemmShape, TilePartitioner, CodegenGemmTraits>(
|
||||
args, s);
|
||||
}
|
||||
|
||||
void RunTest(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
{
|
||||
// Generate test data and run the kernel
|
||||
static_cast<Derived*>(this)->run_test_with_validation(M, N, K);
|
||||
}
|
||||
|
||||
// Helper function to check layout
|
||||
template <typename Layout>
|
||||
static constexpr auto is_row_major(Layout)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(Layout{})>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// Tolerance calculation function for validation
|
||||
template <typename ADataType_, typename BDataType_, typename AccDataType_, typename CDataType_>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType_) < sizeof(BDataType_), ADataType_, BDataType_>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType_, AccDataType_>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType_, AccDataType_>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType_, CDataType_, CDataType_>(kbatch);
|
||||
const auto atol_split_k =
|
||||
ck_tile::get_absolute_threshold<CDataType_, CDataType_, CDataType_>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
};
|
||||
|
||||
// Define generic QuantTypeTraits template (will be specialized)
|
||||
template <ck_tile::QuantType QT>
|
||||
struct QuantTypeTraits
|
||||
{
|
||||
static_assert(QT == ck_tile::QuantType::AQuantGrouped ||
|
||||
QT == ck_tile::QuantType::BQuantGrouped ||
|
||||
QT == ck_tile::QuantType::RowColQuant ||
|
||||
QT == ck_tile::QuantType::TensorQuant,
|
||||
"Unsupported quantization type");
|
||||
};
|
||||
|
||||
// Specialization for AQuantGrouped
|
||||
template <>
|
||||
struct QuantTypeTraits<ck_tile::QuantType::AQuantGrouped>
|
||||
{
|
||||
template <typename ADataType, typename BDataType>
|
||||
using ComputeDataType = BDataType; // For AQuant, compute type is BDataType
|
||||
|
||||
static constexpr const char* name = "aquant";
|
||||
};
|
||||
|
||||
// Specialization for BQuantGrouped
|
||||
template <>
|
||||
struct QuantTypeTraits<ck_tile::QuantType::BQuantGrouped>
|
||||
{
|
||||
template <typename ADataType, typename BDataType>
|
||||
using ComputeDataType = ADataType; // For BQuant, compute type is ADataType
|
||||
|
||||
static constexpr const char* name = "bquant";
|
||||
};
|
||||
|
||||
// Specialization for RowColQuant
|
||||
template <>
|
||||
struct QuantTypeTraits<ck_tile::QuantType::RowColQuant>
|
||||
{
|
||||
template <typename ADataType, typename BDataType>
|
||||
using ComputeDataType = ADataType; // For RowColQuant, compute type is ADataType
|
||||
|
||||
static constexpr const char* name = "rowcol";
|
||||
};
|
||||
|
||||
// Specialization for TensorQuant
|
||||
template <>
|
||||
struct QuantTypeTraits<ck_tile::QuantType::TensorQuant>
|
||||
{
|
||||
template <typename ADataType, typename BDataType>
|
||||
using ComputeDataType = ADataType; // For TensorQuant, compute type is ADataType
|
||||
|
||||
static constexpr const char* name = "tensor";
|
||||
};
|
||||
919
test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
Normal file
919
test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
Normal file
@@ -0,0 +1,919 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "test_gemm_quant_base.hpp"
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
// Default GEMM tile sizes for tests
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>
|
||||
{
|
||||
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>;
|
||||
friend Base;
|
||||
|
||||
public:
|
||||
using typename Base::AccDataType;
|
||||
using typename Base::ADataType;
|
||||
using typename Base::ALayout;
|
||||
using typename Base::BDataType;
|
||||
using typename Base::BLayout;
|
||||
using typename Base::CDataType;
|
||||
using typename Base::CLayout;
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::QDataType;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
void TearDownQuantTypeSpecific() {}
|
||||
|
||||
// AQuant-specific data generation
|
||||
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_C = M;
|
||||
|
||||
// AQuant uses grouped quantization for A matrix
|
||||
const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize);
|
||||
const ck_tile::index_t stride_AQ =
|
||||
ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(ALayout{}));
|
||||
|
||||
// Generate test data
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<QDataType> aq_m_aqk(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
|
||||
|
||||
// Initialize data with random values
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f}(a_m_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f}(a_m_k);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{-2.0f, 2.0f}(aq_m_aqk);
|
||||
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
|
||||
ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size() * sizeof(QDataType));
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
|
||||
|
||||
// Copy to device
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<ADataType> temp = a_m_k;
|
||||
ck_tile::permute_vectors_i4x4_b(temp);
|
||||
a_m_k_dev_buf.ToDevice(temp.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
}
|
||||
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
|
||||
// Create args for kernel execution
|
||||
ck_tile::QuantGemmHostArgs args{
|
||||
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
|
||||
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
aq_m_aqk_dev_buf.GetDeviceBuffer(), // aq_ptr (scales)
|
||||
nullptr, // bq_ptr (not used for AQuant)
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K, // M, N, K
|
||||
AQK, // QK_A
|
||||
0, // QK_B (not used for AQuant)
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
stride_AQ,
|
||||
0 // strides
|
||||
};
|
||||
|
||||
// Run the kernel
|
||||
ck_tile::stream_config stream_config{};
|
||||
this->invoke_quant_gemm(args, stream_config);
|
||||
|
||||
// Validation using reference implementation
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Run reference AQuant implementation
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
true>(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref);
|
||||
|
||||
// Get device result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
|
||||
|
||||
// Calculate error tolerances
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1, max_accumulated_value);
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
EXPECT_TRUE(pass) << "AQuantGrouped validation failed with M=" << M << ", N=" << N
|
||||
<< ", K=" << K;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
std::cout << "AQuantGrouped - Relative error threshold: "
|
||||
<< rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// AQuant-specific pipeline implementation
|
||||
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
|
||||
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
using PipelineProblem =
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
QuantGroupSize,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
ck_tile::QuantType::AQuantGrouped>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Arguments not supported for AQuant kernel");
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
};
|
||||
|
||||
// BQuant-specific test fixture
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmBQuant<Tuple>>
|
||||
{
|
||||
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmBQuant<Tuple>>;
|
||||
friend Base;
|
||||
|
||||
public:
|
||||
using typename Base::AccDataType;
|
||||
using typename Base::ADataType;
|
||||
using typename Base::ALayout;
|
||||
using typename Base::BDataType;
|
||||
using typename Base::BLayout;
|
||||
using typename Base::CDataType;
|
||||
using typename Base::CLayout;
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::QDataType;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
void TearDownQuantTypeSpecific() {}
|
||||
|
||||
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_C = M;
|
||||
|
||||
// BQuant uses grouped quantization for B matrix
|
||||
const ck_tile::index_t BQK = ck_tile::integer_divide_ceil(K, QuantGroupSize);
|
||||
const ck_tile::index_t stride_BQ = BQK;
|
||||
|
||||
// Generate test data
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<QDataType> bq_bqk_n(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, this->is_row_major(BLayout{})));
|
||||
|
||||
// Initialize data with random values
|
||||
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{0.f, 1.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(bq_bqk_n);
|
||||
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
|
||||
ck_tile::DeviceMem bq_bqk_n_dev_buf(bq_bqk_n.get_element_space_size() * sizeof(QDataType));
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
|
||||
|
||||
// Copy to device
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> temp = b_k_n;
|
||||
ck_tile::permute_vectors_i4x4_b(temp);
|
||||
b_k_n_dev_buf.ToDevice(temp.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
|
||||
|
||||
// Create args for kernel execution
|
||||
ck_tile::QuantGemmHostArgs args{
|
||||
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
|
||||
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
nullptr, // aq_ptr (not used for BQuant)
|
||||
bq_bqk_n_dev_buf.GetDeviceBuffer(), // bq_ptr (scales)
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K, // M, N, K
|
||||
0, // QK_A (not used for BQuant)
|
||||
BQK, // QK_B
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
0,
|
||||
stride_BQ // strides
|
||||
};
|
||||
|
||||
// Run the kernel
|
||||
ck_tile::stream_config stream_config{};
|
||||
this->invoke_quant_gemm(args, stream_config);
|
||||
|
||||
// Validation using reference implementation
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Run reference BQuant implementation
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(a_m_k, bq_bqk_n, b_k_n, c_m_n_host_ref);
|
||||
|
||||
// Get device result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
|
||||
|
||||
// Calculate error tolerances
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1, max_accumulated_value);
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
EXPECT_TRUE(pass) << "BQuantGrouped validation failed with M=" << M << ", N=" << N
|
||||
<< ", K=" << K;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
std::cout << "BQuantGrouped - Relative error threshold: "
|
||||
<< rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// BQuant-specific pipeline implementation
|
||||
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
|
||||
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
|
||||
using PipelineProblem =
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
QDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
QuantGroupSize,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
false, // transpose_c
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Arguments not supported for BQuant kernel");
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
};
|
||||
|
||||
// RowColQuant-specific test fixture
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmRowColQuant
|
||||
: public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmRowColQuant<Tuple>>
|
||||
{
|
||||
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmRowColQuant<Tuple>>;
|
||||
friend Base;
|
||||
|
||||
public:
|
||||
using typename Base::AccDataType;
|
||||
using typename Base::ADataType;
|
||||
using typename Base::ALayout;
|
||||
using typename Base::BDataType;
|
||||
using typename Base::BLayout;
|
||||
using typename Base::CDataType;
|
||||
using typename Base::CLayout;
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::QDataType;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
void TearDownQuantTypeSpecific() {}
|
||||
|
||||
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_C = M;
|
||||
|
||||
// RowColQuant uses per-row and per-column scales
|
||||
const ck_tile::index_t stride_row_scales = 1;
|
||||
const ck_tile::index_t stride_col_scales = 1;
|
||||
|
||||
// Generate test data
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<QDataType> row_scales_m(ck_tile::host_tensor_descriptor(
|
||||
M, 1, stride_row_scales, ck_tile::bool_constant<true>{}));
|
||||
ck_tile::HostTensor<QDataType> col_scales_n(ck_tile::host_tensor_descriptor(
|
||||
N, 1, stride_col_scales, ck_tile::bool_constant<true>{}));
|
||||
|
||||
// Initialize data with random values
|
||||
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(row_scales_m);
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(col_scales_n);
|
||||
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
|
||||
ck_tile::DeviceMem row_scales_dev_buf(row_scales_m.get_element_space_size() *
|
||||
sizeof(QDataType));
|
||||
ck_tile::DeviceMem col_scales_dev_buf(col_scales_n.get_element_space_size() *
|
||||
sizeof(QDataType));
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
|
||||
|
||||
// Copy to device
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
row_scales_dev_buf.ToDevice(row_scales_m.data());
|
||||
col_scales_dev_buf.ToDevice(col_scales_n.data());
|
||||
|
||||
// Create args for kernel execution
|
||||
ck_tile::QuantGemmHostArgs args{
|
||||
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
|
||||
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
row_scales_dev_buf.GetDeviceBuffer(), // aq_ptr (row scales)
|
||||
col_scales_dev_buf.GetDeviceBuffer(), // bq_ptr (col scales)
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K, // M, N, K
|
||||
1, // QK_A (row scales)
|
||||
1, // QK_B (col scales)
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
stride_row_scales,
|
||||
stride_col_scales // strides
|
||||
};
|
||||
|
||||
// Run the kernel
|
||||
ck_tile::stream_config stream_config{};
|
||||
this->invoke_quant_gemm(args, stream_config);
|
||||
|
||||
// Validation using reference implementation
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Run reference RowColQuant implementation
|
||||
ck_tile::reference_gemm_rowcol_quant<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
QDataType,
|
||||
AccDataType,
|
||||
CDataType>(
|
||||
a_m_k, row_scales_m, b_k_n, col_scales_n, c_m_n_host_ref);
|
||||
|
||||
// Get device result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
|
||||
|
||||
// Calculate error tolerances
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1, max_accumulated_value);
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
EXPECT_TRUE(pass) << "RowColQuant validation failed with M=" << M << ", N=" << N
|
||||
<< ", K=" << K;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
std::cout << "RowColQuant - Relative error threshold: "
|
||||
<< rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// RowColQuant-specific pipeline implementation
|
||||
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
|
||||
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
using PipelineProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
ck_tile::QuantType::RowColQuant>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Arguments not supported for RowColQuant kernel");
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
};
|
||||
|
||||
// TensorQuant-specific test fixture
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmTensorQuant
|
||||
: public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmTensorQuant<Tuple>>
|
||||
{
|
||||
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmTensorQuant<Tuple>>;
|
||||
friend Base;
|
||||
|
||||
public:
|
||||
using typename Base::AccDataType;
|
||||
using typename Base::ADataType;
|
||||
using typename Base::ALayout;
|
||||
using typename Base::BDataType;
|
||||
using typename Base::BLayout;
|
||||
using typename Base::CDataType;
|
||||
using typename Base::CLayout;
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::QDataType;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
void TearDownQuantTypeSpecific() {}
|
||||
|
||||
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_C = M;
|
||||
|
||||
// TensorQuant uses single scalar scale for each tensor
|
||||
const ck_tile::index_t stride_scale_a = 1;
|
||||
const ck_tile::index_t stride_scale_b = 1;
|
||||
|
||||
// Generate test data
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<QDataType> scale_a(
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_scale_a, ck_tile::bool_constant<true>{}));
|
||||
ck_tile::HostTensor<QDataType> scale_b(
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_scale_b, ck_tile::bool_constant<true>{}));
|
||||
|
||||
// Initialize data with random values
|
||||
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(scale_b);
|
||||
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a.get_element_space_size() * sizeof(QDataType));
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b.get_element_space_size() * sizeof(QDataType));
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
|
||||
|
||||
// Copy to device
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
scale_a_dev_buf.ToDevice(scale_a.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b.data());
|
||||
|
||||
// Create args for kernel execution
|
||||
ck_tile::QuantGemmHostArgs args{
|
||||
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
|
||||
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
scale_a_dev_buf.GetDeviceBuffer(), // aq_ptr (scale A)
|
||||
scale_b_dev_buf.GetDeviceBuffer(), // bq_ptr (scale B)
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K, // M, N, K
|
||||
1, // QK_A (tensor scale)
|
||||
1, // QK_B (tensor scale)
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
stride_scale_a,
|
||||
stride_scale_b // strides
|
||||
};
|
||||
|
||||
// Run the kernel
|
||||
ck_tile::stream_config stream_config{};
|
||||
this->invoke_quant_gemm(args, stream_config);
|
||||
|
||||
// Validation using reference implementation
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Run reference TensorQuant implementation
|
||||
ck_tile::reference_gemm_tensor_quant<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
QDataType,
|
||||
AccDataType,
|
||||
CDataType>(
|
||||
a_m_k, scale_a, b_k_n, scale_b, c_m_n_host_ref);
|
||||
|
||||
// Get device result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
|
||||
|
||||
// Calculate error tolerances
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1, max_accumulated_value);
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
EXPECT_TRUE(pass) << "TensorQuant validation failed with M=" << M << ", N=" << N
|
||||
<< ", K=" << K;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
std::cout << "TensorQuant - Relative error threshold: "
|
||||
<< rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// TensorQuant-specific pipeline implementation
|
||||
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
|
||||
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
using PipelineProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
ck_tile::QuantType::TensorQuant>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Arguments not supported for TensorQuant kernel");
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
};
|
||||
64
test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp
Normal file
64
test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp
Normal file
@@ -0,0 +1,64 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
|
||||
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
|
||||
using GroupSize = std::integral_constant<unsigned int, 128>;
|
||||
|
||||
// Type combinations for each quantization type
|
||||
// clang-format off
|
||||
using AQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
using BQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
using RowColQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, RowColQuant, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, RowColQuant, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
using TensorQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, TensorQuant, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, TensorQuant, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suites for each quantization type
|
||||
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes);
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes);
|
||||
TYPED_TEST_SUITE(TestCkTileGemmRowColQuant, RowColQuantTypes);
|
||||
TYPED_TEST_SUITE(TestCkTileGemmTensorQuant, TensorQuantTypes);
|
||||
|
||||
#include "test_gemm_quant_ut_cases.inc"
|
||||
28
test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc
Normal file
28
test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc
Normal file
@@ -0,0 +1,28 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
// AQuant tests
|
||||
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
|
||||
// BQuant tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
|
||||
// RowColQuant tests
|
||||
TYPED_TEST(TestCkTileGemmRowColQuant, RowColQuantTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
|
||||
// TensorQuant tests
|
||||
TYPED_TEST(TestCkTileGemmTensorQuant, TensorQuantTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -1,616 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <random>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_aquant_utils.hpp"
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ComputeDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize>
|
||||
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
false, // preshuffle
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ck_tile::QuantType::AQuantGrouped>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
constexpr bool transposed_warp_gemm = false;
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
QuantGroupSize,
|
||||
transposed_warp_gemm,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
transposed_warp_gemm,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
CodegenGemmPipeline,
|
||||
GemmEpilogue,
|
||||
ck_tile::QuantType::AQuantGrouped>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(args.k_batch != 1)
|
||||
{
|
||||
throw std::runtime_error("split-k is not supported yet!");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenGemmShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t AQK,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_AQ,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
ck_tile::QuantGemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.QK_A = AQK;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
args.stride_AQ = stride_AQ;
|
||||
|
||||
float ave_time = gemm_calc_aquant<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
BDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantGroupSize>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(AQDataType) * M * AQK +
|
||||
sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B
|
||||
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
|
||||
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name
|
||||
<< " A_Type = " << DataTypeTraits<ADataType>::name
|
||||
<< " AQ_Type = " << DataTypeTraits<AQDataType>::name
|
||||
<< " B_Type = " << DataTypeTraits<BDataType>::name
|
||||
<< " Acc_Type = " << DataTypeTraits<AccDataType>::name
|
||||
<< " C_Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
uint32_t QuantGroupSize,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
bool run_gemm_test_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const AQLayout aq_layout = AQLayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return false;
|
||||
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using AQDataType = typename TypeConfig::QDataType;
|
||||
using BDataType = typename TypeConfig::BDataType;
|
||||
using AccDataType = typename TypeConfig::AccDataType;
|
||||
using CDataType = typename TypeConfig::CDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
if(K % QuantGroupSize != 0)
|
||||
{
|
||||
throw std::runtime_error("K must be aligned with QuantGroupSize");
|
||||
}
|
||||
|
||||
ck_tile::index_t AQK = K / QuantGroupSize;
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_AQ = arg_parser.get_int("stride_q");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<AQDataType> aq_m_aqk(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(aq_m_aqk);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
std::cout << "Monotonic initialization is not supported." << std::endl;
|
||||
return true;
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
|
||||
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(aq_m_aqk);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
aq_m_aqk.SetZero();
|
||||
b_k_n.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
|
||||
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
|
||||
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
}
|
||||
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantGroupSize>(a_m_k_dev_buf,
|
||||
aq_m_aqk_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
AQK,
|
||||
stride_A,
|
||||
stride_AQ,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
true>(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename TypeConfig, uint32_t QuantGroupSize>
|
||||
bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfig, TypeConfig, QuantGroupSize>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for A.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
bool run_gemm_test(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return false;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4fp8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4bf8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4f32fp8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4f32bf8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int run_gemm_combinations(std::string const& data_type)
|
||||
{
|
||||
// Define possible values for each parameter
|
||||
std::vector<std::vector<std::string>> mnk_values = {{
|
||||
"1",
|
||||
"2048",
|
||||
"5120",
|
||||
},
|
||||
{
|
||||
"2",
|
||||
"2048",
|
||||
"5120",
|
||||
},
|
||||
{
|
||||
"16",
|
||||
"2048",
|
||||
"5120",
|
||||
},
|
||||
{
|
||||
"17",
|
||||
"2048",
|
||||
"5120",
|
||||
},
|
||||
{
|
||||
"2047",
|
||||
"5120",
|
||||
"1024",
|
||||
},
|
||||
{
|
||||
"2048",
|
||||
"5120",
|
||||
"1024",
|
||||
}};
|
||||
std::vector<std::string> prec_values = {data_type};
|
||||
|
||||
// We'll store all our arguments as strings first
|
||||
std::vector<std::string> arg_strings = {"test_tile_gemm_aquant_basic",
|
||||
"", // m placeholder
|
||||
"", // n placeholder
|
||||
"", // k placeholder
|
||||
"", // prec placeholder
|
||||
"-init=0",
|
||||
"-v=1",
|
||||
"-warmup=0",
|
||||
"-repeat=1"};
|
||||
|
||||
// Create an array of const char pointers for argv
|
||||
constexpr size_t ARG_COUNT = 9;
|
||||
constexpr size_t ARG_MAX_LEN = 64;
|
||||
char args[ARG_COUNT][ARG_MAX_LEN];
|
||||
char* argv[ARG_COUNT];
|
||||
|
||||
// Run all combinations
|
||||
bool is_success = true;
|
||||
for(const auto& mnk : mnk_values)
|
||||
{
|
||||
arg_strings[1] = "-m=" + mnk[0];
|
||||
arg_strings[2] = "-n=" + mnk[1];
|
||||
arg_strings[3] = "-k=" + mnk[2];
|
||||
|
||||
for(const auto& prec : prec_values)
|
||||
{
|
||||
arg_strings[4] = "-prec=" + prec;
|
||||
|
||||
// Set up the argv array with pointers to the string data
|
||||
for(size_t i = 0; i < ARG_COUNT; i++)
|
||||
{
|
||||
strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN);
|
||||
argv[i] = args[i];
|
||||
}
|
||||
|
||||
std::cout << "Arguments received: ";
|
||||
for(size_t i = 1; i < ARG_COUNT; ++i)
|
||||
{
|
||||
std::cout << argv[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
// Call the function with the current configuration
|
||||
try
|
||||
{
|
||||
is_success = run_gemm_test<GemmConfigDecode>(ARG_COUNT, argv) && is_success;
|
||||
}
|
||||
catch(const ArgumentsNotSupportedException& e)
|
||||
{
|
||||
std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n';
|
||||
// ArgumentsNotSupportedException is not an error. Do not change is_success
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Caught runtime error: " << e.what() << '\n';
|
||||
is_success = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
|
||||
}
|
||||
Reference in New Issue
Block a user