mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
feat(block_scale_gemm): Support RRR-R, CRR-R and CCR-C layout for aquant quant mode (#3193)
* [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * feat(gemm_quant): add RRR and CRR layout support for aquant gemm * test(gemm_quant): add unit tests for RRR and CRR layout support for aquant gemm * fix: compilation error on gfx950 by omitting support for the gpu in example and unit tests * fix: test cases compilation failure due to PR# 2095 * fix: make condition to filter out tests for gfx950 more explicit * need to support the gfx950 * fix: add layout suppot for gfx950 * Extend pk_int4_t support for block_scale_gemm aquant CR and RR layout (#3277) * WIP: add support for pk_int4_t for aquant mode layouts RR and CR * test(block_scale_gemm): add unit tests for CRR and RRR layout when data type is int4 && aquant * fix: compile time error for gfx950 * fix: minor bug where is_a_load_tr_v() was mising * feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant (#3318) * feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant * test: add unit tests for new layout support CCRC for aquant block scale gemm * docs: update changelog with new layout support info * Update CHANGELOG.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * refactor: break test instances into multiple cpp files to reduce build time (#3319) * feat(block_scale_gemm): Add layout Col-Col-Row-Col (ABC-Aquant) for tensors in aquant * test: add unit tests for new layout support CCRC for aquant block scale gemm * refactor: break test instances into multiple cpp files to reduce build time * chore: rename file for better code readability * fix: merge conflict resolution * fix: remove memory pipeline because new layout is not compatible * build: resolve build errors for gfx950 by modifying is_a_load_tr() & is_b_load_tr() * refactor: address review comments * solve the conflict --------- Co-authored-by: Cong Ma <congma13@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -48,7 +48,7 @@ CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
|
||||
}
|
||||
else
|
||||
{
|
||||
dst = load_tile(src);
|
||||
load_tile(dst, src);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -26,18 +26,32 @@ struct GemmPipelineAgBgCrImplBase
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
#if defined(__gfx950__)
|
||||
// The combination of pk_int4_t and transposed loading causes numerical errors.
|
||||
// The combination of pk_int4_t and transposed loading causes compilation errors.
|
||||
// Therefore do not use transposed loading in this case.
|
||||
// Also, transpose load (ds_read_tr) requires specific tile distribution patterns
|
||||
// that only work for certain K warp tile sizes based on data type size:
|
||||
// - For 1-byte types (fp8/bf8): K warp tile <= 64
|
||||
// - For 2-byte types (fp16/bf16): K warp tile <= 32
|
||||
static constexpr bool is_a_load_tr = []() {
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else if constexpr(kKWarpTile > kMaxKWarpTile)
|
||||
return false;
|
||||
else
|
||||
return std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
}();
|
||||
|
||||
static constexpr bool is_b_load_tr = []() {
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else if constexpr(kKWarpTile > kMaxKWarpTile)
|
||||
return false;
|
||||
else
|
||||
return std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
}();
|
||||
@@ -93,19 +107,21 @@ struct GemmPipelineAgBgCrImplBase
|
||||
load_tile(dst_block_tile, lds_tile_window);
|
||||
}
|
||||
|
||||
template <typename OverrideADataType = ADataType, typename OverrideBDataType = BDataType>
|
||||
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
|
||||
{
|
||||
// A tile in LDS
|
||||
ADataType* __restrict__ p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
|
||||
OverrideADataType* __restrict__ p_a_lds = static_cast<OverrideADataType*>(p_smem);
|
||||
constexpr auto a_lds_block_desc =
|
||||
Policy::template MakeALdsBlockDescriptor<Problem, OverrideADataType>();
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
// TODO: LDS alignment should come from Policy!
|
||||
constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple(
|
||||
sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);
|
||||
sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size(), 16);
|
||||
|
||||
// B tile in LDS
|
||||
BDataType* __restrict__ p_b_lds = static_cast<BDataType*>(
|
||||
OverrideBDataType* __restrict__ p_b_lds = static_cast<OverrideBDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
@@ -18,7 +18,8 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked;
|
||||
static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked;
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem,
|
||||
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
|
||||
@@ -37,11 +37,22 @@ struct UniversalGemmBasePolicy
|
||||
#if defined(__gfx950__)
|
||||
// The combination of pk_int4_t and transposed loading causes numerical errors.
|
||||
// Therefore do not use transposed loading in this case.
|
||||
// Also, transpose load (ds_read_tr) requires specific tile distribution patterns
|
||||
// that only work for certain K warp tile sizes based on data type size:
|
||||
// - For 1-byte types (fp8/bf8): K warp tile <= 64
|
||||
// - For 2-byte types (fp16/bf16): K warp tile <= 32
|
||||
template <typename Problem>
|
||||
static constexpr bool is_a_load_tr = []() {
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
// Max K warp tile for transpose load based on data type size
|
||||
constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else if constexpr(kKWarpTile > kMaxKWarpTile)
|
||||
return false;
|
||||
else
|
||||
return std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
|
||||
tensor_layout::gemm::ColumnMajor>;
|
||||
@@ -49,9 +60,15 @@ struct UniversalGemmBasePolicy
|
||||
|
||||
template <typename Problem>
|
||||
static constexpr bool is_b_load_tr = []() {
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
// Max K warp tile for transpose load based on data type size
|
||||
constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else if constexpr(kKWarpTile > kMaxKWarpTile)
|
||||
return false;
|
||||
else
|
||||
return std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
|
||||
tensor_layout::gemm::RowMajor>;
|
||||
@@ -87,13 +104,12 @@ struct UniversalGemmBasePolicy
|
||||
return DefaultBTileAccessPattern;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem,
|
||||
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
|
||||
CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = OverrideADataType;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
@@ -435,12 +435,22 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
ALdsTile a_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_>(a_warp_tile_, a_block_window);
|
||||
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_>(b_warp_tile_, b_block_window);
|
||||
// while ADatatype might not be the same as BDataType at the time of problem
|
||||
// initialization, we can safely use BDataType here because when A would be int4 we will
|
||||
// ensure A is converted to BDataType prior to loading
|
||||
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
|
||||
a_warp_tile_, a_block_window);
|
||||
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
|
||||
b_warp_tile_, b_block_window);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
@@ -522,11 +532,16 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
MakeCBlockTile();
|
||||
}
|
||||
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> a_load_tr = {},
|
||||
bool_constant<BLoadTranspose> b_load_tr = {})
|
||||
{
|
||||
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
|
||||
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
|
||||
@@ -414,7 +414,6 @@ struct QuantGemmKernel
|
||||
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
if(kargs.QK_A % GemmPipeline::GetVectorSizeAQ() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
@@ -655,13 +654,24 @@ struct QuantGemmKernel
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.M, kargs.QK_A),
|
||||
make_tuple(kargs.stride_AQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.M, kargs.QK_A),
|
||||
make_tuple(kargs.stride_AQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else // Column major AQ
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.QK_A, kargs.M), // Swapped dimensions
|
||||
make_tuple(kargs.stride_AQ, 1), // Same stride pattern
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
@@ -946,14 +956,21 @@ struct QuantGemmKernel
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto block_k = TilePartitioner::KPerBlock;
|
||||
return make_tile_window(
|
||||
aq_pad_view,
|
||||
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
|
||||
{i_m, 0});
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(aq_pad_view,
|
||||
make_tuple(number<block_m>{}, number<aqk_per_block>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else // Column major AQ
|
||||
{
|
||||
return make_tile_window(aq_pad_view,
|
||||
make_tuple(number<aqk_per_block>{}, number<block_m>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
|
||||
@@ -20,8 +20,6 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
@@ -36,8 +34,6 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
GetAQDramLoadWindow(const AQDramBlockWindowTmp& aq_dram_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
|
||||
auto aq_copy_dram_window =
|
||||
make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
aq_dram_block_window_tmp.get_window_lengths(),
|
||||
|
||||
@@ -18,13 +18,11 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ()
|
||||
{
|
||||
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
|
||||
static_assert(std::is_same_v<AQLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
return GetABQGlobalVectorLoadSize<Problem, AQDataType, MPerBlock, KPerBlockAQ>();
|
||||
}
|
||||
|
||||
@@ -49,7 +47,6 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_aq<
|
||||
@@ -68,6 +65,8 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
{
|
||||
if constexpr(Problem::TransposeC)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>,
|
||||
"TransposeC currently only supports RowMajor layout");
|
||||
using TileEncodingPatternTransposeC =
|
||||
tile_distribution_encoding_pattern_aq_transposed_c<BlockGemmShape,
|
||||
WarpGemm,
|
||||
@@ -79,16 +78,34 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
}
|
||||
else
|
||||
{
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_aq<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlockAQ,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
// !Problem::TransposeC
|
||||
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_aq<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlockAQ,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
else
|
||||
{
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_aq<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
KPerBlockAQ, // YPerTile
|
||||
MPerBlock, // XPerTile
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution_transposed();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,6 +77,9 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
using Base::PrefetchStages;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
@@ -161,6 +164,16 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <typename ADramWindow, typename ABlockTile_>
|
||||
CK_TILE_DEVICE static void LoadAndConvertATile(ABlockTile_& a_block_tile,
|
||||
const ADramWindow& a_dram_window)
|
||||
{
|
||||
using DestDataType = typename ABlockTile_::DataType;
|
||||
using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
|
||||
constexpr index_t UnaryOpSize = 8;
|
||||
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(a_block_tile, a_dram_window);
|
||||
}
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
@@ -177,6 +190,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
@@ -192,8 +206,6 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
std::is_same_v<AQLayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)");
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
@@ -211,7 +223,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
|
||||
auto&& [a_lds_block, b_lds_block] =
|
||||
Base::template GetABLdsTensorViews<BDataType, BDataType>(p_smem);
|
||||
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
@@ -228,8 +241,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// while ADatatype might not be the same as BDataType at the time of problem
|
||||
// initialization, we can safely use BDataType here because when A would be int4 we will
|
||||
// ensure A is converted to BDataType prior to loading
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
decltype(make_static_distributed_tensor<BDataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
using AQBlockTile =
|
||||
@@ -251,23 +267,25 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
|
||||
// only row_major for AQ
|
||||
const AQDramTileWindowStep aq_dram_tile_window_step =
|
||||
PreshuffleQuant ? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
|
||||
BlockGemm::WarpGemm::kM,
|
||||
0)
|
||||
: make_array(0, KPerBlockAQ);
|
||||
PreshuffleQuant
|
||||
? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
|
||||
BlockGemm::WarpGemm::kM,
|
||||
0)
|
||||
: (is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ));
|
||||
|
||||
// DRAM prefetch (global read 0)
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step);
|
||||
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
@@ -276,10 +294,10 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
@@ -288,12 +306,14 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -304,9 +324,9 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
@@ -315,7 +335,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -327,7 +347,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2],
|
||||
aq_copy_dram_window,
|
||||
@@ -340,7 +361,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
@@ -363,9 +385,9 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
|
||||
currIdx = (currIdx + 1) % 2;
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
@@ -374,7 +396,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -386,7 +408,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
}
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm(
|
||||
c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
@@ -405,7 +428,10 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
// Note: a_element_func takes BDataType (not ADataType) because A tiles are
|
||||
// converted from ADataType (e.g., pk_int4_t) to BDataType (e.g., fp8) in
|
||||
// LoadAndConvertATile before the element function is applied.
|
||||
[](const BDataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
aq_dram_block_window_tmp,
|
||||
|
||||
@@ -110,6 +110,27 @@ struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution_transposed()
|
||||
{
|
||||
|
||||
constexpr index_t Y0 = YPerTile;
|
||||
constexpr index_t X0 = 1;
|
||||
constexpr index_t X1 = MIterPerWarp ? MIterPerWarp : 1;
|
||||
constexpr index_t X2 = MWarps;
|
||||
constexpr index_t X3 = WarpGemm::kM;
|
||||
|
||||
static_assert(X3 >= WarpGemm::kM, "Scales for all rows must be available within the warp.");
|
||||
static_assert(X0 * X1 * X2 * X3 == XPerTile,
|
||||
"X0, X1, X2, X3 must cover the blocktile along X.");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWarps>,
|
||||
tuple<sequence<Y0>, sequence<X0, X1, X2, X3>>,
|
||||
tuple<sequence<2, 0>, sequence<2, 2>>,
|
||||
tuple<sequence<2, 0>, sequence<0, 3>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 0>>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename BlockGemmShape,
|
||||
|
||||
Reference in New Issue
Block a user