[CK_TILE] Tensor-wise scaled quant gemm kernel (#2846)

* rename gemm_group_quant to gemm_quant

* Add TensorWise quant mode

* Cshuffle epilogue tests with tensor scaling

* Add tensor quant to example

* Don't use readfirstlane for reading scales - doesn't work for some reason

* Add to changelog

* revert include - from a merge problem?

* revert common.hpp include

* revert host.hpp include

* remove unused utility function

* rename quant pipeline problem

* refactor quant tests

* remove aquant utils

* use TEST_F

* fix all tests by changing gemm config

* Use typed tests

* fix copyright

[ROCm/composable_kernel commit: 4363a82bd6]
This commit is contained in:
Sami Remes
2025-09-20 02:52:35 +03:00
committed by GitHub
parent ee43f0f0be
commit 8d2a444c55
39 changed files with 1555 additions and 1056 deletions

View File

@@ -180,10 +180,6 @@ CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, fp8_t>)
{
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
}
else
{
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
@@ -198,7 +194,57 @@ CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k
};
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
std::cout << std::endl;
}
template <typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor<ADataType>& a_m_k,
const HostTensor<AQDataType>& aq_1_1,
const HostTensor<BDataType>& b_k_n,
const HostTensor<BQDataType>& bq_1_1,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
static_assert(std::is_same_v<AccDataType, float>);
static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
// Init accumulator
AccDataType v_acc = 0;
// Get scale for A and scale for B
const AccDataType a_scale = ck_tile::type_convert<AccDataType>(aq_1_1(0, 0));
const AccDataType b_scale = ck_tile::type_convert<AccDataType>(bq_1_1(0, 0));
// Compute the dot product
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
AccDataType v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
v_acc += v_a * v_b;
}
v_acc = v_acc * a_scale * b_scale;
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
};
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}
template <typename ADataType,