Merge commit 'f2cfc6b94ee3154697030c4dfa214040bb4af4c9' into develop

This commit is contained in:
assistant-librarian[bot]
2025-11-13 19:11:21 +00:00
parent 0997e2eb6d
commit acd5abe4f1
38 changed files with 352 additions and 1888 deletions

View File

@@ -204,7 +204,7 @@ struct tile_window_with_static_distribution
typename ElementWise_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
CK_TILE_DEVICE void load(DistributedTensor& dst_tensor,
const TileWindow_& tile_window,
ElementWise_ elementwise,
number<i_access_unsupport_> = {},
@@ -283,7 +283,7 @@ struct tile_window_with_static_distribution
template <typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
CK_TILE_DEVICE void load(DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
@@ -431,7 +431,7 @@ struct tile_window_with_static_distribution
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
CK_TILE_DEVICE void async_load_raw(LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
@@ -515,7 +515,7 @@ struct tile_window_with_static_distribution
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>>>>
CK_TILE_DEVICE auto async_load_with_offset(index_t offset,
CK_TILE_DEVICE void async_load_with_offset(index_t offset,
LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
@@ -605,7 +605,7 @@ struct tile_window_with_static_distribution
typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_transpose_with_offset(index_t offset,
CK_TILE_DEVICE void load_transpose_with_offset(index_t offset,
DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const

View File

@@ -8,16 +8,7 @@
namespace ck_tile {
template <class T>
struct is_pk_int4 : std::false_type
{
};
template <>
struct is_pk_int4<pk_int4_t> : std::true_type
{
};
template <typename ComputeDataType, index_t UnaryOpSize>
template <typename DstDataType, index_t UnaryOpSize>
struct InterleavedPKTypeLoader
{
template <typename WarpWindow, typename WarpTile>
@@ -30,24 +21,30 @@ struct InterleavedPKTypeLoader
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
using DstVectorType = DstDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
elementwise_op(warp_tile.get_thread_buffer().template get_as<DstVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
});
}
};
template <typename BDataType,
typename ComputeDataType,
template <typename SrcDataType,
typename DstDataType,
index_t UnaryOpSize,
bool LoadTranspose = false,
typename WarpTile,
typename WarpWindow>
CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
{
if constexpr(is_pk_int4<std::remove_cv_t<BDataType>>::value)
if constexpr(std::is_same_v<SrcDataType, pk_int4_t>)
{
InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t");
InterleavedPKTypeLoader<DstDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
}
else if constexpr(LoadTranspose)
{
dst = load_tile_transpose(src);
}
else
{

View File

@@ -94,7 +94,11 @@ struct BlockUniversalGemmAsBsCr
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
@@ -196,8 +200,8 @@ struct BlockUniversalGemmAsBsCr
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
@@ -222,22 +226,10 @@ struct BlockUniversalGemmAsBsCr
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!");
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else
{
load_tile(a_warp_tile_, a_block_window);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else
{
load_tile(b_warp_tile_, b_block_window);
}
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
// hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
@@ -285,8 +277,8 @@ struct BlockUniversalGemmAsBsCr
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
@@ -300,30 +292,10 @@ struct BlockUniversalGemmAsBsCr
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else if constexpr(ALoadTranspose)
{
a_warp_tile_ = load_tile_transpose(a_block_window);
}
else
{
load_tile(a_warp_tile_, a_block_window);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else if constexpr(BLoadTranspose)
{
b_warp_tile_ = load_tile_transpose(b_block_window);
}
else
{
load_tile(b_warp_tile_, b_block_window);
}
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
}
// C += A * B
@@ -396,8 +368,8 @@ struct BlockUniversalGemmAsBsCr
static constexpr auto BLdsTileDistr =
make_static_tile_distribution(MakeBBlockDistributionEncode());
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
@@ -451,30 +423,10 @@ struct BlockUniversalGemmAsBsCr
auto b_lds_gemm_window = make_tile_window(
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else if constexpr(ALoadTranspose)
{
a_warp_tile_ = load_tile_transpose(a_lds_gemm_window);
}
else
{
load_tile(a_warp_tile_, a_lds_gemm_window);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else if constexpr(BLoadTranspose)
{
b_warp_tile_ = load_tile_transpose(b_lds_gemm_window);
}
else
{
load_tile(b_warp_tile_, b_lds_gemm_window);
}
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_lds_gemm_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_lds_gemm_window);
}
// C += A * B

View File

@@ -26,8 +26,21 @@ struct GemmPipelineAgBgCrImplBase
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
#if defined(__gfx950__)
static constexpr bool is_a_load_tr = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
static constexpr bool is_b_load_tr = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
// The combination of pk_int4_t and transposed loading causes numerical errors.
// Therefore do not use transposed loading in this case.
static constexpr bool is_a_load_tr = []() {
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else
return std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
}();
static constexpr bool is_b_load_tr = []() {
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else
return std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
}();
#else
static constexpr bool is_a_load_tr = false;
static constexpr bool is_b_load_tr = false;

View File

@@ -33,12 +33,27 @@ template <typename Derived>
struct UniversalGemmBasePolicy
{
#if defined(__gfx950__)
// The combination of pk_int4_t and transposed loading causes numerical errors.
// Therefore do not use transposed loading in this case.
template <typename Problem>
static constexpr bool is_a_load_tr =
std::is_same_v<remove_cvref_t<typename Problem::ALayout>, tensor_layout::gemm::ColumnMajor>;
static constexpr bool is_a_load_tr = []() {
using BDataType = remove_cvref_t<typename Problem::BDataType>;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else
return std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
tensor_layout::gemm::ColumnMajor>;
}();
template <typename Problem>
static constexpr bool is_b_load_tr =
std::is_same_v<remove_cvref_t<typename Problem::BLayout>, tensor_layout::gemm::RowMajor>;
static constexpr bool is_b_load_tr = []() {
using BDataType = remove_cvref_t<typename Problem::BDataType>;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else
return std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
tensor_layout::gemm::RowMajor>;
}();
#else
template <typename Problem>
static constexpr bool is_a_load_tr = false;
@@ -707,8 +722,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
: WGAttrNumAccessEnum::Invalid;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
@@ -718,8 +740,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
Problem::UseStructuredSparsity,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;

View File

@@ -5,7 +5,6 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
@@ -156,7 +155,6 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
using Base = BlockGemmAQuantBase<Problem_>;
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
@@ -447,26 +445,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else
{
load_tile(a_warp_tile_, a_block_window);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else
{
load_tile(b_warp_tile_, b_block_window);
}
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_>(a_warp_tile_, a_block_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_>(b_warp_tile_, b_block_window);
}
// C += A * B

View File

@@ -5,7 +5,6 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
@@ -155,7 +154,6 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using Base = BlockGemmBQuantBase<Problem_>;
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
@@ -273,26 +271,8 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else
{
load_tile(a_warp_tile_, a_block_window);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else
{
load_tile(b_warp_tile_, b_block_window);
}
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_>(a_warp_tile_, a_block_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_>(b_warp_tile_, b_block_window);
}
// C += A * B

View File

@@ -7,7 +7,6 @@
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"