feat(block_scale_gemm): Support RRR-R, CRR-R and CCR-C layout for aquant quant mode (#3193)

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Split cpp file to reduce building time
- Support multiple GemmConfig

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update Readme

* feat(gemm_quant): add RRR and CRR layout support for aquant gemm

* test(gemm_quant): add unit tests for RRR and CRR layout support for aquant gemm

* fix: compilation error on gfx950 by omitting support for the gpu in example and unit tests

* fix: test cases compilation failure due to PR# 2095

* fix: make condition to filter out tests for gfx950 more explicit

* need to support the gfx950

* fix: add layout support for gfx950

* Extend pk_int4_t support for block_scale_gemm aquant CR and RR layout (#3277)

* WIP: add support for pk_int4_t for aquant mode layouts RR and CR

* test(block_scale_gemm): add unit tests for CRR and RRR layout when data type is int4 && aquant

* fix: compile time error for gfx950

* fix: minor bug where is_a_load_tr_v() was missing

* feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant (#3318)

* feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant

* test: add unit tests for new layout support CCRC for aquant block scale gemm

* docs: update changelog with new layout support info

* Update CHANGELOG.md

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* refactor: break test instances into multiple cpp files to reduce build time (#3319)

* feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant

* test: add unit tests for new layout support CCRC for aquant block scale gemm

* refactor: break test instances into multiple cpp files to reduce build time

* chore: rename file for better code readability

* fix: merge conflict resolution

* fix: remove memory pipeline because new layout is not compatible

* build: resolve build errors for gfx950 by modifying is_a_load_tr() & is_b_load_tr()

* refactor: address review comments

* solve the conflict

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Aviral Goel
2025-12-03 02:59:07 +04:00
committed by GitHub
parent 2c284a1780
commit 6cb0bc2d11
22 changed files with 603 additions and 289 deletions

View File

@@ -48,7 +48,7 @@ CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
}
else
{
dst = load_tile(src);
load_tile(dst, src);
}
}

View File

@@ -26,18 +26,32 @@ struct GemmPipelineAgBgCrImplBase
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
#if defined(__gfx950__)
// The combination of pk_int4_t and transposed loading causes numerical errors.
// The combination of pk_int4_t and transposed loading causes compilation errors.
// Therefore do not use transposed loading in this case.
// Also, transpose load (ds_read_tr) requires specific tile distribution patterns
// that only work for certain K warp tile sizes based on data type size:
// - For 1-byte types (fp8/bf8): K warp tile <= 64
// - For 2-byte types (fp16/bf16): K warp tile <= 32
static constexpr bool is_a_load_tr = []() {
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else if constexpr(kKWarpTile > kMaxKWarpTile)
return false;
else
return std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
}();
static constexpr bool is_b_load_tr = []() {
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else if constexpr(kKWarpTile > kMaxKWarpTile)
return false;
else
return std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
}();
@@ -93,19 +107,21 @@ struct GemmPipelineAgBgCrImplBase
load_tile(dst_block_tile, lds_tile_window);
}
template <typename OverrideADataType = ADataType, typename OverrideBDataType = BDataType>
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
{
// A tile in LDS
ADataType* __restrict__ p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
OverrideADataType* __restrict__ p_a_lds = static_cast<OverrideADataType*>(p_smem);
constexpr auto a_lds_block_desc =
Policy::template MakeALdsBlockDescriptor<Problem, OverrideADataType>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// TODO: LDS alignment should come from Policy!
constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple(
sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);
sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size(), 16);
// B tile in LDS
BDataType* __restrict__ p_b_lds = static_cast<BDataType*>(
OverrideBDataType* __restrict__ p_b_lds = static_cast<OverrideBDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);

View File

@@ -18,7 +18,8 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked;
static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked;
template <typename Problem>
template <typename Problem,
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;

View File

@@ -37,11 +37,22 @@ struct UniversalGemmBasePolicy
#if defined(__gfx950__)
// The combination of pk_int4_t and transposed loading causes numerical errors.
// Therefore do not use transposed loading in this case.
// Also, transpose load (ds_read_tr) requires specific tile distribution patterns
// that only work for certain K warp tile sizes based on data type size:
// - For 1-byte types (fp8/bf8): K warp tile <= 64
// - For 2-byte types (fp16/bf16): K warp tile <= 32
template <typename Problem>
static constexpr bool is_a_load_tr = []() {
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
// Max K warp tile for transpose load based on data type size
constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else if constexpr(kKWarpTile > kMaxKWarpTile)
return false;
else
return std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
tensor_layout::gemm::ColumnMajor>;
@@ -49,9 +60,15 @@ struct UniversalGemmBasePolicy
template <typename Problem>
static constexpr bool is_b_load_tr = []() {
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
// Max K warp tile for transpose load based on data type size
constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else if constexpr(kKWarpTile > kMaxKWarpTile)
return false;
else
return std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
tensor_layout::gemm::RowMajor>;
@@ -87,13 +104,12 @@ struct UniversalGemmBasePolicy
return DefaultBTileAccessPattern;
}
template <typename Problem>
template <typename Problem,
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = OverrideADataType;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();

View File

@@ -435,12 +435,22 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_>(a_warp_tile_, a_block_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_>(b_warp_tile_, b_block_window);
// while ADatatype might not be the same as BDataType at the time of problem
// initialization, we can safely use BDataType here because when A would be int4 we will
// ensure A is converted to BDataType prior to loading
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
}
// C += A * B
@@ -522,11 +532,16 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
MakeCBlockTile();
}
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
}
// C += A * B

View File

@@ -414,7 +414,6 @@ struct QuantGemmKernel
if constexpr(kQuantType == QuantType::AQuantGrouped)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if(kargs.QK_A % GemmPipeline::GetVectorSizeAQ() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
@@ -655,13 +654,24 @@ struct QuantGemmKernel
}
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK_A),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK_A),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
else // Column major AQ
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.QK_A, kargs.M), // Swapped dimensions
make_tuple(kargs.stride_AQ, 1), // Same stride pattern
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
@@ -946,14 +956,21 @@ struct QuantGemmKernel
}
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto block_k = TilePartitioner::KPerBlock;
return make_tile_window(
aq_pad_view,
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
{i_m, 0});
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto block_m = TilePartitioner::MPerBlock;
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(aq_pad_view,
make_tuple(number<block_m>{}, number<aqk_per_block>{}),
{i_m, 0});
}
else // Column major AQ
{
return make_tile_window(aq_pad_view,
make_tuple(number<aqk_per_block>{}, number<block_m>{}),
{0, i_m});
}
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{

View File

@@ -20,8 +20,6 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
using BlockGemmShape = typename Base::BlockGemmShape;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
@@ -36,8 +34,6 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
CK_TILE_DEVICE constexpr auto
GetAQDramLoadWindow(const AQDramBlockWindowTmp& aq_dram_block_window_tmp) const
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
auto aq_copy_dram_window =
make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(),
aq_dram_block_window_tmp.get_window_lengths(),

View File

@@ -18,13 +18,11 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ()
{
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::QuantGroupSize::kK;
static_assert(std::is_same_v<AQLayout, ck_tile::tensor_layout::gemm::RowMajor>);
return GetABQGlobalVectorLoadSize<Problem, AQDataType, MPerBlock, KPerBlockAQ>();
}
@@ -49,7 +47,6 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
WarpTile::at(I2),
Problem::TransposeC>;
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if constexpr(PreshuffleQuant)
{
using TileEncodingPattern = tile_distribution_encoding_pattern_aq<
@@ -68,6 +65,8 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
{
if constexpr(Problem::TransposeC)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>,
"TransposeC currently only supports RowMajor layout");
using TileEncodingPatternTransposeC =
tile_distribution_encoding_pattern_aq_transposed_c<BlockGemmShape,
WarpGemm,
@@ -79,16 +78,34 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
}
else
{
using TileEncodingPattern = tile_distribution_encoding_pattern_aq<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
// !Problem::TransposeC
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
using TileEncodingPattern =
tile_distribution_encoding_pattern_aq<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution();
return TileEncodingPattern::make_2d_static_tile_distribution();
}
else
{
using TileEncodingPattern =
tile_distribution_encoding_pattern_aq<BlockGemmShape,
WarpGemm,
BlockSize,
KPerBlockAQ, // YPerTile
MPerBlock, // XPerTile
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution_transposed();
}
}
}
}

View File

@@ -77,6 +77,9 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
using Base::PrefetchStages;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
@@ -161,6 +164,16 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
{
using Base = PipelineImplBase;
template <typename ADramWindow, typename ABlockTile_>
CK_TILE_DEVICE static void LoadAndConvertATile(ABlockTile_& a_block_tile,
const ADramWindow& a_dram_window)
{
using DestDataType = typename ABlockTile_::DataType;
using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
constexpr index_t UnaryOpSize = 8;
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(a_block_tile, a_dram_window);
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
@@ -177,6 +190,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
index_t num_loop,
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
@@ -192,8 +206,6 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
std::is_same_v<AQLayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)");
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
@@ -211,7 +223,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex;
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
auto&& [a_lds_block, b_lds_block] =
Base::template GetABLdsTensorViews<BDataType, BDataType>(p_smem);
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
@@ -228,8 +241,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution());
// while ADatatype might not be the same as BDataType at the time of problem
// initialization, we can safely use BDataType here because when A would be int4 we will
// ensure A is converted to BDataType prior to loading
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
decltype(make_static_distributed_tensor<BDataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
using AQBlockTile =
@@ -251,23 +267,25 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
// only row_major for AQ
const AQDramTileWindowStep aq_dram_tile_window_step =
PreshuffleQuant ? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
BlockGemm::WarpGemm::kM,
0)
: make_array(0, KPerBlockAQ);
PreshuffleQuant
? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
BlockGemm::WarpGemm::kM,
0)
: (is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ));
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(
aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
auto a_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
@@ -276,10 +294,10 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
@@ -288,12 +306,14 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
__builtin_amdgcn_sched_barrier(0);
@@ -304,9 +324,9 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
{
block_sync_lds();
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
auto a_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
@@ -315,7 +335,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -327,7 +347,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2],
aq_copy_dram_window,
@@ -340,7 +361,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
__builtin_amdgcn_sched_barrier(0);
i += 1;
@@ -363,9 +385,9 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
currIdx = (currIdx + 1) % 2;
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
auto a_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
@@ -374,7 +396,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -386,7 +408,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
block_gemm(
c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
}
@@ -405,7 +428,10 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
// Note: a_element_func takes BDataType (not ADataType) because A tiles are
// converted from ADataType (e.g., pk_int4_t) to BDataType (e.g., fp8) in
// LoadAndConvertATile before the element function is applied.
[](const BDataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
aq_dram_block_window_tmp,

View File

@@ -110,6 +110,27 @@ struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding
sequence<0, 0>>{});
}
}
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution_transposed()
{
constexpr index_t Y0 = YPerTile;
constexpr index_t X0 = 1;
constexpr index_t X1 = MIterPerWarp ? MIterPerWarp : 1;
constexpr index_t X2 = MWarps;
constexpr index_t X3 = WarpGemm::kM;
static_assert(X3 >= WarpGemm::kM, "Scales for all rows must be available within the warp.");
static_assert(X0 * X1 * X2 * X3 == XPerTile,
"X0, X1, X2, X3 must cover the blocktile along X.");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>,
tuple<sequence<Y0>, sequence<X0, X1, X2, X3>>,
tuple<sequence<2, 0>, sequence<2, 2>>,
tuple<sequence<2, 0>, sequence<0, 3>>,
sequence<2, 1>,
sequence<1, 0>>{});
}
};
template <typename BlockGemmShape,