mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Finalize cleanup
This commit is contained in:
@@ -825,13 +825,13 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<BQDataType, ck_tile::e8m0_t>)
|
||||
ck_tile::reference_mxfp4gemm_quant<ADataType,
|
||||
BQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
BQuantGroupSize,
|
||||
false>(
|
||||
ck_tile::reference_mx_gemm_bquant<ADataType,
|
||||
BQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
BQuantGroupSize,
|
||||
false>(
|
||||
a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
|
||||
else
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
|
||||
@@ -392,13 +392,13 @@ template <typename ADataType,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename ACCElementOp = ck_tile::identity>
|
||||
CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<QDataType>& q,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
CK_TILE_HOST void reference_mx_gemm_bquant(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<QDataType>& q,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
|
||||
@@ -216,6 +216,7 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
using BTypeTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
BTypeTile b_warp_tile_lds_;
|
||||
|
||||
// Load from LDS (assumption is that the scale will be applied in the block gemm)
|
||||
template <typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
@@ -232,6 +233,7 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
b_warp_tile_, b_block_window);
|
||||
}
|
||||
|
||||
// Load from LDS and scale (then the tile can directly be consumed in the block gemm)
|
||||
template <typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
typename BQRegBlockTile,
|
||||
@@ -243,16 +245,34 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
|
||||
a_warp_tile_, a_block_window);
|
||||
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
|
||||
b_warp_tile_lds_, b_block_window);
|
||||
// Load tile from LDS
|
||||
|
||||
// Apply scale
|
||||
using BDataTypeRaw = typename std::
|
||||
conditional<std::is_same_v<BDataType, pk_fp4_t>, pk_fp4_t::type, BDataType>::type;
|
||||
// Do not use load_int4_tile here because it will have support to cast from fp4 to
|
||||
// compute type, while here we want to only load from LDS and then apply the scale
|
||||
// and cast later
|
||||
if constexpr(ALoadTranspose)
|
||||
{
|
||||
a_warp_tile_ = load_tile_transpose(a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_block_window);
|
||||
}
|
||||
|
||||
constexpr auto warp_size = get_warp_size();
|
||||
if constexpr(BLoadTranspose)
|
||||
{
|
||||
b_warp_tile_lds_ = load_tile_transpose(b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_lds_, b_block_window);
|
||||
}
|
||||
|
||||
// Apply scale and cast
|
||||
using BDataTypeRaw =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_fp4_t>, pk_fp4_t::type, BDataType>;
|
||||
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
constexpr index_t nelements = WarpGemm::kK * WarpGemm::kN / warp_size;
|
||||
constexpr index_t thread_buffer_size = nelements / UnaryOpSize_;
|
||||
const element_wise::DequantPack8 elementwise_op{};
|
||||
@@ -262,6 +282,22 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
|
||||
// B scale register offset
|
||||
constexpr index_t reg_offset = [&]() {
|
||||
if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN))
|
||||
return (nIter * NWarp * WarpGemm::kN) / GemmTraits::QuantGroupSize::kN *
|
||||
Traits::KQPerBlock +
|
||||
kQScale;
|
||||
else
|
||||
{
|
||||
return nIter * Traits::KQPerBlock + kQScale;
|
||||
}
|
||||
}();
|
||||
|
||||
// Get B scale from thread buffer
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float b_scale_f = float(scale_reg);
|
||||
|
||||
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
|
||||
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
|
||||
// Thread buffers
|
||||
@@ -275,27 +311,12 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
BWarpThreadBuffer b_warp_thread_buffer;
|
||||
BLDSThreadBuffer b_lds_thread_buffer;
|
||||
|
||||
// BQuant register offset
|
||||
constexpr index_t reg_offset = [&]() {
|
||||
if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN))
|
||||
return (nIter * NWarp * WarpGemm::kN) /
|
||||
GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock +
|
||||
kQScale;
|
||||
else
|
||||
{
|
||||
return nIter * Traits::KQPerBlock + kQScale;
|
||||
}
|
||||
}();
|
||||
|
||||
// Load thread buffer from tile (LDS type)
|
||||
b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// Apply scale to thread buffer and cast
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float b_scale_f = float(scale_reg);
|
||||
|
||||
// Apply scale to B thread buffer and cast
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
elementwise_op(
|
||||
b_warp_thread_buffer.template get_as<DstVectorType>()(i),
|
||||
@@ -303,7 +324,7 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
b_scale_f);
|
||||
});
|
||||
|
||||
// Store thread buffer to tile (MMA type)
|
||||
// Store B thread buffer to tile (MMA type)
|
||||
b_warp_tile_.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths),
|
||||
|
||||
@@ -70,6 +70,10 @@ struct GemmMxPipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution()
|
||||
{
|
||||
// If we apply scale before writing to LDS, we need a tile distribution for
|
||||
// BQuant consistent with global memory reading of matrix B, while
|
||||
// if we apply scale after reading from LDS, we need a tile distribution for
|
||||
// BQuant consistent with the MMA instructions layout
|
||||
if constexpr(Problem::BCastPolicy == CastPolicy::AfterLDSRead)
|
||||
{
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
|
||||
@@ -787,13 +787,13 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
|
||||
// Run reference BQuant implementation
|
||||
if constexpr(std::is_same_v<QDataType, ck_tile::e8m0_t>)
|
||||
ck_tile::reference_mxfp4gemm_quant<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
|
||||
ck_tile::reference_mx_gemm_bquant<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
|
||||
else
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
QDataType,
|
||||
|
||||
Reference in New Issue
Block a user