Finalize cleanup

This commit is contained in:
Enrico Degregori
2026-01-30 11:27:27 +00:00
parent afd3d2bd10
commit b6a66c19e8
5 changed files with 71 additions and 46 deletions

View File

@@ -825,13 +825,13 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
if constexpr(std::is_same_v<BQDataType, ck_tile::e8m0_t>)
ck_tile::reference_mxfp4gemm_quant<ADataType,
BQDataType,
BDataType,
AccDataType,
CDataType,
BQuantGroupSize,
false>(
ck_tile::reference_mx_gemm_bquant<ADataType,
BQDataType,
BDataType,
AccDataType,
CDataType,
BQuantGroupSize,
false>(
a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
else
ck_tile::reference_gemm_quant<ADataType,

View File

@@ -392,13 +392,13 @@ template <typename ADataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor<ADataType>& a_m_k,
const HostTensor<QDataType>& q,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
CK_TILE_HOST void reference_mx_gemm_bquant(const HostTensor<ADataType>& a_m_k,
const HostTensor<QDataType>& q,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);

View File

@@ -216,6 +216,7 @@ struct BQuantBlockUniversalGemmAsBsCr
using BTypeTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
BTypeTile b_warp_tile_lds_;
// Load from LDS (assumption is that the scale will be applied in the block gemm)
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
@@ -232,6 +233,7 @@ struct BQuantBlockUniversalGemmAsBsCr
b_warp_tile_, b_block_window);
}
// Load from LDS and scale (then the tile can directly be consumed in the block gemm)
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
typename BQRegBlockTile,
@@ -243,16 +245,34 @@ struct BQuantBlockUniversalGemmAsBsCr
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_lds_, b_block_window);
// Load tile from LDS
// Apply scale
using BDataTypeRaw = typename std::
conditional<std::is_same_v<BDataType, pk_fp4_t>, pk_fp4_t::type, BDataType>::type;
// Do not use load_int4_tile here: it will support casting from fp4 to the compute
// type, whereas here we only want to load from LDS and then apply the scale and
// cast afterwards
if constexpr(ALoadTranspose)
{
a_warp_tile_ = load_tile_transpose(a_block_window);
}
else
{
load_tile(a_warp_tile_, a_block_window);
}
constexpr auto warp_size = get_warp_size();
if constexpr(BLoadTranspose)
{
b_warp_tile_lds_ = load_tile_transpose(b_block_window);
}
else
{
load_tile(b_warp_tile_lds_, b_block_window);
}
// Apply scale and cast
using BDataTypeRaw =
std::conditional_t<std::is_same_v<BDataType, pk_fp4_t>, pk_fp4_t::type, BDataType>;
constexpr index_t warp_size = get_warp_size();
constexpr index_t nelements = WarpGemm::kK * WarpGemm::kN / warp_size;
constexpr index_t thread_buffer_size = nelements / UnaryOpSize_;
const element_wise::DequantPack8 elementwise_op{};
@@ -262,6 +282,22 @@ struct BQuantBlockUniversalGemmAsBsCr
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
// B scale register offset
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN))
return (nIter * NWarp * WarpGemm::kN) / GemmTraits::QuantGroupSize::kN *
Traits::KQPerBlock +
kQScale;
else
{
return nIter * Traits::KQPerBlock + kQScale;
}
}();
// Get B scale from thread buffer
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_f = float(scale_reg);
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
// Thread buffers
@@ -275,27 +311,12 @@ struct BQuantBlockUniversalGemmAsBsCr
BWarpThreadBuffer b_warp_thread_buffer;
BLDSThreadBuffer b_lds_thread_buffer;
// BQuant register offset
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN))
return (nIter * NWarp * WarpGemm::kN) /
GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock +
kQScale;
else
{
return nIter * Traits::KQPerBlock + kQScale;
}
}();
// Load thread buffer from tile (LDS type)
b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// Apply scale to thread buffer and cast
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_f = float(scale_reg);
// Apply scale to B thread buffer and cast
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(
b_warp_thread_buffer.template get_as<DstVectorType>()(i),
@@ -303,7 +324,7 @@ struct BQuantBlockUniversalGemmAsBsCr
b_scale_f);
});
// Store thread buffer to tile (MMA type)
// Store B thread buffer to tile (MMA type)
b_warp_tile_.set_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths),

View File

@@ -70,6 +70,10 @@ struct GemmMxPipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution()
{
// If the scale is applied before writing to LDS, the BQuant tile distribution
// must be consistent with how matrix B is read from global memory; if the scale
// is applied after reading from LDS, it must instead be consistent with the
// layout of the MMA instructions
if constexpr(Problem::BCastPolicy == CastPolicy::AfterLDSRead)
{
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;

View File

@@ -787,13 +787,13 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
// Run reference BQuant implementation
if constexpr(std::is_same_v<QDataType, ck_tile::e8m0_t>)
ck_tile::reference_mxfp4gemm_quant<ADataType,
QDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
ck_tile::reference_mx_gemm_bquant<ADataType,
QDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
else
ck_tile::reference_gemm_quant<ADataType,
QDataType,