[CK_TILE] Tensor-wise scaled quant gemm kernel (#2846)

* rename gemm_group_quant to gemm_quant

* Add TensorWise quant mode

* CShuffle epilogue tests with tensor scaling

* Add tensor quant to example

* Don't use readfirstlane for reading scales - doesn't work for some reason

* Add to changelog

* revert include - likely introduced by a merge problem

* revert common.hpp include

* revert host.hpp include

* remove unused utility function

* rename quant pipeline problem

* refactor quant tests

* remove aquant utils

* use TEST_F

* fix all tests by changing gemm config

* Use typed tests

* fix copyright

[ROCm/composable_kernel commit: 4363a82bd6]
Sami Remes
2025-09-20 02:52:35 +03:00
committed by GitHub
parent ee43f0f0be
commit 8d2a444c55
39 changed files with 1555 additions and 1056 deletions
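For context, tensor-wise (per-tensor) quantization keeps one float scale for the whole A tensor and one for the whole B tensor, so dequantization reduces to a single scalar multiply per output element after the fp8 dot product. A minimal statement of what the new reference below verifies, with s_A and s_B taken from the 1x1 scale tensors aq_1_1 and bq_1_1:

C_{m,n} = s_A \cdot s_B \cdot \sum_{k} A_{m,k} \, B_{k,n}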

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -158,4 +158,7 @@ CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<Windo
{
}
template <typename Tile>
concept IsLoadableTile = requires { load_tile(std::declval<Tile>()); };
} // namespace ck_tile

View File

@@ -180,10 +180,6 @@ CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, fp8_t>)
{
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
}
else
{
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
@@ -198,7 +194,57 @@ CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k
};
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
std::cout << std::endl;
}
template <typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor<ADataType>& a_m_k,
const HostTensor<AQDataType>& aq_1_1,
const HostTensor<BDataType>& b_k_n,
const HostTensor<BQDataType>& bq_1_1,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
static_assert(std::is_same_v<AccDataType, float>);
static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
// Init accumulator
AccDataType v_acc = 0;
// Get scale for A and scale for B
const AccDataType a_scale = ck_tile::type_convert<AccDataType>(aq_1_1(0, 0));
const AccDataType b_scale = ck_tile::type_convert<AccDataType>(bq_1_1(0, 0));
// Compute the dot product
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
AccDataType v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
v_acc += v_a * v_b;
}
v_acc = v_acc * a_scale * b_scale;
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
};
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}
template <typename ADataType,
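A hedged host-side usage sketch for the new reference_gemm_tensor_quant above. The namespace, the HostTensor length-list constructor, and the example sizes are assumptions for illustration, not copied from this commit:

// Sketch only: assumes ck_tile::HostTensor can be built from a length list
// and that reference_gemm_tensor_quant lives in namespace ck_tile.
const std::size_t M = 256, N = 256, K = 128;
ck_tile::HostTensor<ck_tile::fp8_t>  a_m_k({M, K});   // fp8 A operand, (m, k)
ck_tile::HostTensor<ck_tile::fp8_t>  b_k_n({K, N});   // fp8 B operand, (k, n)
ck_tile::HostTensor<float>           aq_1_1({1, 1});  // single per-tensor scale for A
ck_tile::HostTensor<float>           bq_1_1({1, 1});  // single per-tensor scale for B
ck_tile::HostTensor<ck_tile::half_t> c_m_n({M, N});   // reference output, (m, n)
aq_1_1(0, 0) = 0.5f; // example scale values
bq_1_1(0, 0) = 2.0f;
// AccDataType (the fifth template argument) cannot be deduced, so the types are spelled out:
ck_tile::reference_gemm_tensor_quant<ck_tile::fp8_t, float, ck_tile::fp8_t, float, float, ck_tile::half_t>(
    a_m_k, aq_1_1, b_k_n, bq_1_1, c_m_n);

The accumulation runs in float and a_scale * b_scale is applied once per output element, matching the formula above.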

View File

@@ -304,22 +304,41 @@ struct CShuffleEpilogue
CK_TILE_DEVICE void
scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
{
// Load tiles
const auto scale_m_tile = load_tile(scale_m_window);
const auto scale_n_tile = load_tile(scale_n_window);
// Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
tile_elementwise_inout(
element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
// Move scale windows
constexpr index_t num_access = SFC::get_num_of_access();
if constexpr(iAccess != num_access - 1)
// Check if scales are EmptyScale first (no scaling needed)
if constexpr(std::is_same_v<ScaleM, EmptyScale> && std::is_same_v<ScaleN, EmptyScale>)
{
constexpr auto step = SFC::get_forward_step(iAccess);
// No scaling needed - this is a no-op
}
// Check if scales are scalar AccDataType
else if constexpr(std::is_same_v<ScaleM, AccDataType> &&
std::is_same_v<ScaleN, AccDataType>)
{
// Handle scalar scales
const AccDataType scale_m = scale_m_window;
const AccDataType scale_n = scale_n_window;
tile_elementwise_inout([&](auto& element) { element = element * scale_m * scale_n; },
lds_tile);
}
// Otherwise, assume they are tile windows that can be loaded
else
{
// Load tiles
const auto scale_m_tile = load_tile(scale_m_window);
const auto scale_n_tile = load_tile(scale_n_window);
move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
// Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
tile_elementwise_inout(
element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
// Move scale windows
constexpr index_t num_access = SFC::get_num_of_access();
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
}
}
}
@@ -452,6 +471,8 @@ struct CShuffleEpilogue
// Optional scales (must share the same distribution to match per-thread indexing)
constexpr bool has_scales =
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
constexpr bool has_scalar_scales =
std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
// Tiles to hold row/col scales when present
using SMType = typename GetDataType<remove_cvref_t<ScaleM>>::type;
@@ -462,8 +483,11 @@ struct CShuffleEpilogue
// Build windows only if scales are provided
auto scale_m_window = [&]() {
if constexpr(has_scales)
if constexpr(has_scales && !has_scalar_scales)
{
static_assert(
IsLoadableTile<decltype(make_tile_window(scale_m, dram_tile_distribution))>,
"ScaleM must be a loadable tile");
return make_tile_window(scale_m, dram_tile_distribution);
}
else
@@ -472,8 +496,11 @@ struct CShuffleEpilogue
}
}();
auto scale_n_window = [&]() {
if constexpr(has_scales)
if constexpr(has_scales && !has_scalar_scales)
{
static_assert(
IsLoadableTile<decltype(make_tile_window(scale_n, dram_tile_distribution))>,
"ScaleN must be a loadable tile");
return make_tile_window(scale_n, dram_tile_distribution);
}
else
@@ -489,7 +516,7 @@ struct CShuffleEpilogue
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
// If scales provided, load them with identical distribution
if constexpr(has_scales)
if constexpr(has_scales && IsLoadableTile<ScaleM> && IsLoadableTile<ScaleN>)
{
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
@@ -504,7 +531,11 @@ struct CShuffleEpilogue
auto emit = [&](index_t out_idx, index_t src_row) {
AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row];
if constexpr(has_scales)
if constexpr(has_scalar_scales)
{
v = static_cast<AccDataType>(v * scale_m * scale_n);
}
else if constexpr(has_scales)
{
// same linear index mapping on the permuted distribution
const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
@@ -595,10 +626,19 @@ struct CShuffleEpilogue
number<NumDTensor>{});
constexpr bool has_scales =
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
!std::is_same_v<ScaleM, EmptyScale> && !std::is_same_v<ScaleN, EmptyScale>;
constexpr bool has_scalar_scales =
std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
auto scale_m_window = [&]() {
if constexpr(has_scales)
if constexpr(has_scalar_scales)
{
return scale_m;
}
else if constexpr(has_scales)
{
static_assert(
IsLoadableTile<decltype(make_tile_window(scale_m, dram_tile_distribution))>,
"ScaleM must be a loadable tile");
return make_tile_window(scale_m, lds_tile.get_tile_distribution());
}
else
@@ -607,8 +647,15 @@ struct CShuffleEpilogue
}
}();
auto scale_n_window = [&]() {
if constexpr(has_scales)
if constexpr(has_scalar_scales)
{
return scale_n;
}
else if constexpr(has_scales)
{
static_assert(
IsLoadableTile<decltype(make_tile_window(scale_n, dram_tile_distribution))>,
"ScaleN must be a loadable tile");
return make_tile_window(scale_n, lds_tile.get_tile_distribution());
}
else
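The epilogue change above dispatches on the scale parameter types at compile time: EmptyScale means no scaling, a plain AccDataType scalar means tensor-wise scaling, and anything else is assumed to be a loadable tile window of row/column scales. A self-contained sketch of that if-constexpr idiom, using hypothetical names rather than the actual CK Tile types:

#include <cstddef>
#include <type_traits>
#include <vector>
struct EmptyScaleTag {}; // stands in for ck_tile's EmptyScale
template <typename Scale>
void apply_scale(std::vector<float>& tile, const Scale& scale)
{
    if constexpr(std::is_same_v<Scale, EmptyScaleTag>)
    {
        // no scaling requested - this branch compiles away entirely
    }
    else if constexpr(std::is_same_v<Scale, float>)
    {
        for(float& v : tile) // tensor-wise: one scalar applied to every element
            v *= scale;
    }
    else
    {
        // otherwise assume an indexable, element-wise scale source
        for(std::size_t i = 0; i < tile.size(); ++i)
            tile[i] *= scale[i];
    }
}

In the real epilogue ScaleM/ScaleN play the role of Scale, and the tile-window branch additionally advances the scale windows between accesses.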

View File

@@ -1,21 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/grouped_gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -0,0 +1,21 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -12,7 +12,7 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp"
namespace ck_tile {
@@ -330,7 +330,6 @@ struct QuantGemmKernel
}
}
// NOTE: no kernel currently uses BQuant like this:
if constexpr(kQuantType == QuantType::BQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
@@ -890,6 +889,7 @@ struct QuantGemmKernel
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param aq_ptr input AQ pointer
* @param bq_ptr input BQ pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
@@ -938,7 +938,8 @@ struct QuantGemmKernel
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
else if constexpr(kQuantType == QuantType::RowColQuant ||
kQuantType == QuantType::TensorQuant)
{
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0);
@@ -964,6 +965,18 @@ struct QuantGemmKernel
aq_block_window,
bq_block_window);
}
else if constexpr(kQuantType == QuantType::TensorQuant)
{
// TODO: why doesn't readfirstlane work here?
// const AccDataType aq_scale =
// __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*aq_ptr));
// const AccDataType bq_scale =
// __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*bq_ptr));
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
}
}
CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const

View File

@@ -9,7 +9,7 @@
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/host.hpp"
#include <hip/hip_runtime.h>

View File

@@ -9,7 +9,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {

View File

@@ -9,7 +9,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {

View File

@@ -168,17 +168,18 @@ template <typename ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
using GemmRowColQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
AccDataType_,
BDataType_,
AccDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
1, // no group size applicable
TransposeC_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
using GemmRowColTensorQuantPipelineProblem =
GemmQuantPipelineProblemBase<ADataType_,
AccDataType_,
BDataType_,
AccDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
1, // no group size applicable
TransposeC_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
} // namespace ck_tile

View File

@@ -12,9 +12,22 @@ enum struct QuantType : std::uint16_t
{
AQuantGrouped = 0,
BQuantGrouped = 1,
RowColQuant = 2
RowColQuant = 2,
TensorQuant = 3
};
std::string quant_type_to_string(QuantType quant_type)
{
switch(quant_type)
{
case QuantType::AQuantGrouped: return "AQuantGrouped";
case QuantType::BQuantGrouped: return "BQuantGrouped";
case QuantType::RowColQuant: return "RowColQuant";
case QuantType::TensorQuant: return "TensorQuant";
default: return "Unknown";
}
}
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
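A small, hypothetical host-side helper (not part of this commit) showing how the new TensorQuant value and quant_type_to_string could be combined when picking a reference check; the ck_tile namespace placement is an assumption:

#include <iostream>
void verify(ck_tile::QuantType qt)
{
    std::cout << "verifying " << ck_tile::quant_type_to_string(qt) << " GEMM" << std::endl;
    switch(qt)
    {
    case ck_tile::QuantType::TensorQuant:
        // reference_gemm_tensor_quant(a_m_k, aq_1_1, b_k_n, bq_1_1, c_m_n);
        break;
    case ck_tile::QuantType::RowColQuant:
        // reference_gemm_rowcol_quant(a_m_k, aq, b_k_n, bq, c_m_n);
        break;
    default:
        break; // AQuantGrouped / BQuantGrouped are handled elsewhere
    }
}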