mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Add gemm weight preshuffle pk_int_t support (#2858)
* Factor out the three separate copies of load_interleaved_pk_type into a common utility class * Add preprocessing with optional cache flushing and clearing of output for k_batch > 1 to the weight preshuffle GEMM example * Remove a duplicate function * Add support for B tensor type pk_int4_t for the weight preshuffle GEMM, with tests included * I4 support introduced more failing test cases that mirror the existing ones for F8 * Simplify the check for which tests to skip (they all have F8 as A tensor type) * Add a changelog entry * add the test for v2 wp pipeline, polish the code, add the support of int4 for v2 wp pipeline * have a workable version for atomic add * Revert "have a workable version for atomic add" This reverts commit 792377a590c26cfff9c8f545d9a9e8484a7422eb. --------- Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -5,19 +5,19 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, index_t UnaryOpSize_ = 8>
|
||||
template <typename Problem>
|
||||
struct BlockGemmAQuantBase
|
||||
{
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
|
||||
static constexpr index_t UnaryOpSize = UnaryOpSize_;
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale)
|
||||
{
|
||||
@@ -42,23 +42,6 @@ struct BlockGemmAQuantBase
|
||||
}
|
||||
return scale_reg_f;
|
||||
}
|
||||
|
||||
template <typename WarpWindow, typename WarpTile>
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
|
||||
const WarpWindow& warp_window)
|
||||
{
|
||||
const element_wise::PassThroughPack8 elementwise_op{};
|
||||
|
||||
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
|
||||
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
|
||||
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// A is block window on shared memory
|
||||
@@ -66,7 +49,9 @@ struct BlockGemmAQuantBase
|
||||
// Consecutive kQuantGroupSize elements of A are quantized with a separate scale.
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
|
||||
template <typename Problem_,
|
||||
typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy,
|
||||
index_t UnaryOpSize_ = 8>
|
||||
struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
{
|
||||
private:
|
||||
@@ -172,6 +157,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
|
||||
using Base = BlockGemmAQuantBase<Problem_>;
|
||||
|
||||
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
|
||||
@@ -292,7 +278,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Base::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -302,7 +288,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Base::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -5,19 +5,19 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, index_t UnaryOpSize_ = 8>
|
||||
template <typename Problem>
|
||||
struct BlockGemmBQuantBase
|
||||
{
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
|
||||
static constexpr index_t UnaryOpSize = UnaryOpSize_;
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale)
|
||||
{
|
||||
@@ -42,24 +42,6 @@ struct BlockGemmBQuantBase
|
||||
}
|
||||
return scale_reg_f;
|
||||
}
|
||||
|
||||
// can be inherited from A
|
||||
template <typename WarpWindow, typename WarpTile>
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
|
||||
const WarpWindow& warp_window)
|
||||
{
|
||||
const element_wise::PassThroughPack8 elementwise_op{};
|
||||
|
||||
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
|
||||
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
|
||||
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// A is block window on shared memory
|
||||
@@ -67,7 +49,9 @@ struct BlockGemmBQuantBase
|
||||
// Consecutive kQuantGroupSize elements of B are quantized with a separate scale.
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
|
||||
template <typename Problem_,
|
||||
typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy,
|
||||
index_t UnaryOpSize_ = 8>
|
||||
struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
{
|
||||
private:
|
||||
@@ -170,6 +154,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
|
||||
using Base = BlockGemmBQuantBase<Problem_>;
|
||||
|
||||
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
|
||||
@@ -291,7 +276,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Base::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -301,7 +286,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Base::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user