Support transposed C tile in Aquant (#2679)

The performance of Aquant has increased after enabling transposed C.

There is no need to exchange AQ elements among lanes after enabling
transposed C, since each thread only holds data from a single row.
This commit is contained in:
Cong Ma
2025-08-28 14:28:09 -06:00
committed by GitHub
parent 0758883fa4
commit 428090f749
10 changed files with 276 additions and 154 deletions

View File

@@ -158,6 +158,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool TransposeC = Problem::TransposeC;
};
public:
@@ -359,63 +360,181 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
if constexpr(Traits::PreshuffleQuant)
{
// A view is created on top of the preshuffled AQ, where each row of the
// view is composed of a row from a warp tile within an AQ block tile.
// Multiple warp tile rows that belong to the same block tile are laid
// out as consecutive rows.
//
// When we need to multiply a C warp tile with an AQ warp tile, thread 0
// in the warp will load AQ_warp_tile[0], thread 1 will load
// AQ_warp_tile[1], and so on, up to thread 63, which will load
// AQ_warp_tile[63]. The VGPR file in the warp acts similarly to LDS in
// this context, but we use cross-lane operations to access the data.
// (Cross-lane operations are faster than using LDS.)
//
// Note that when the size of the AQ warp tile is smaller than the warp
// size, you need to pad the rows in the view to ensure that each thread
// can read one element.
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
constexpr uint32_t kTileRowsOfCPerThread = 4;
if constexpr(Traits::TransposeC) // transposed C
{
static_assert(false,
"It is not supported yet to enable both Preshuffle.");
// TODO:
// A new tile distribution is needed for the Preshuffle and
// Transpose combination. For instance, with mnk at 16x16x32, lanes
// 0-15, 16-31, 32-47, and 48-63 must load the same elements of AQ.
}
else
{
// A view is created on top of the preshuffled AQ, where each row of
// the view is composed of a row from a warp tile within an AQ block
// tile. Multiple warp tile rows that belong to the same block tile
// are laid out as consecutive rows.
//
// When we need to multiply a C warp tile with an AQ warp tile,
// thread 0 in the warp will load AQ_warp_tile[0], thread 1 will
// load AQ_warp_tile[1], and so on, up to thread 63, which will load
// AQ_warp_tile[63]. The VGPR file in the warp acts similarly to LDS
// in this context, but we use cross-lane operations to access the
// data. (Cross-lane operations are faster than using LDS.)
//
// Note that when the size of the AQ warp tile is smaller than the
// warp size, you need to pad the rows in the view to ensure that
// each thread can read one element.
constexpr auto tbuf_offset = number<
typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
constexpr uint32_t kTileRowsOfCPerThread = 4;
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
// For a warp tile of [16x16x32], take thread 0 as an example.
// Its VGPR[0] stores the value from C_tile[0,0], VGPR[1] stores
// C_tile[1,0], VGPR[2] stores C_tile[2,0], and VGPR[3] stores
// C_tile[3,0]. This means VGPR[0] should be multiplied by
// AQ_tile[0, 0], VGPR[1] by AQ_tile[1, 0], VGPR[2] by
// AQ_tile[2, 0], and VGPR[3] by AQ_tile[3, 0].
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
// For a warp tile of [16x16x32], take thread 0 as an
// example. Its VGPR[0] stores the value from C_tile[0,0],
// VGPR[1] stores C_tile[1,0], VGPR[2] stores C_tile[2,0],
// and VGPR[3] stores C_tile[3,0]. This means VGPR[0] should
// be multiplied by AQ_tile[0, 0], VGPR[1] by AQ_tile[1, 0],
// VGPR[2] by AQ_tile[2, 0], and VGPR[3] by AQ_tile[3, 0].
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, 0]
// from thread 1, ..., and AQ_tile[3, 0] from thread 3.
decltype(threadIdx.x) pull_from_lane = 0;
if constexpr(WarpGemm::kM == 16)
{
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
kTileRowsOfCPerThread +
c_row) *
Traits::QScalesPerBlockRow +
kQScale;
}
else if constexpr(WarpGemm::kM == 32)
{
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
kTileRowsOfCPerThread +
((c_row >> 2) << 3) + (c_row & 0b11)) *
Traits::QScalesPerBlockRow +
kQScale;
}
else
{
static_assert(false, "WarpGemm::kM is not 16 nor 32.");
}
auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1,
// 0] from thread 1, ..., and AQ_tile[3, 0] from thread 3.
decltype(threadIdx.x) pull_from_lane = 0;
if constexpr(WarpGemm::kM == 16)
{
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
kTileRowsOfCPerThread +
c_row) *
Traits::QScalesPerBlockRow +
kQScale;
}
else if constexpr(WarpGemm::kM == 32)
{
pull_from_lane =
(__lane_id() / Traits::WarpGemm::kN *
kTileRowsOfCPerThread +
((c_row >> 2) << 3) + (c_row & 0b11)) *
Traits::QScalesPerBlockRow +
kQScale;
}
else
{
static_assert(false, "WarpGemm::kM is not 16 nor 32.");
}
auto& scale_reg =
aq_block_tensor.get_thread_buffer()[mIter];
// cross lane ops
// cross lane ops
uint32_t scale_reg_dword;
if constexpr(std::is_same_v<AQDataType, float>)
{
scale_reg_dword =
ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2,
__builtin_bit_cast(int, scale_reg_dword));
float scale_reg_f =
Base::cvt_scale_to_fp32(gathered_scale_reg);
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] *
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
});
}
}
else
{
if(Traits::TransposeC) // transposed C
{
constexpr index_t reg_offset = mIter * Traits::AQPerBlock + kQScale;
constexpr auto tbuf_offset = number<
typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
auto& scale_reg = aq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] *
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
});
}
else
{
// Need to multiply aquant with accumulated C
//
// The accumulated C tile has the standard distribution. For example
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
// [26,0], [27,0].
//
// These elements are in different rows, need to get the scale value
// for the corresponding row.
// Based on aquant's tile distribution, it can be inferred which
// lane holds the relevant scale. For example, the scales
// corresponding to the 16 elements held by lane 0 are held by lanes
// 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
// respectively.
//
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
// MIters per warp
constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM;
// Reg block offset based on mIter
constexpr index_t reg_block_offset =
((mIter / mIters_per_warp) * Traits::AQPerBlock);
constexpr index_t lane_base_offset =
(mIter % mIters_per_warp) * WarpGemm::kM;
// Scale tensor offset along K
constexpr index_t src_reg_offset = reg_block_offset + kQScale;
constexpr uint32_t kTileRows = 4;
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
constexpr auto tbuf_offset = number<
typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
static_for<0, WarpGemm::kM, WarpGemm::kCMLane>{}([&](auto c_row) {
// Multiply by 4 because output is stored in tiles of 4
// x CNLane
constexpr uint32_t row_base =
((c_row / kTiledCMsPerWarp) * kTiledCMsPerWarp) +
((c_row % kTiledCMsPerWarp) / WarpGemm::kCMLane);
constexpr uint32_t reg_offset_for_row_data =
c_row / WarpGemm::kCMLane;
// Lane index to source scale from
uint32_t src_lane_idx =
lane_base_offset + row_base +
(__lane_id() / WarpGemm::kN * kTileRows);
// Directly index into thread buffer corresponding to
// desired row coefficient
auto& scale_reg =
aq_block_tensor.get_thread_buffer()[src_reg_offset];
uint32_t scale_reg_dword;
if constexpr(std::is_same_v<AQDataType, float>)
@@ -427,97 +546,19 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
// Pull scale data across lanes
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2,
__builtin_bit_cast(int, scale_reg_dword));
src_lane_idx * 4, __builtin_bit_cast(int, scale_reg_dword));
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f *
kA_cvt_scale * kB_cvt_scale);
c_block_tensor.get_thread_buffer()[tbuf_offset +
reg_offset_for_row_data] +=
(c_warp_tensor
.get_thread_buffer()[reg_offset_for_row_data] *
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
});
}
else
{
// Need to multiply aquant with accumulated C
//
// The accumulated C tile has the standard distribution. For example
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
// [26,0], [27,0].
//
// These elements are in different rows, need to get the scale value
// for the corresponding row.
// Based on aquant's tile distribution, it can be inferred which
// lane holds the relevant scale. For example, the scales corresponding
// to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9,
// 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 respectively.
//
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
// MIters per warp
constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM;
// Reg block offset based on mIter
constexpr index_t reg_block_offset =
((mIter / mIters_per_warp) * Traits::AQPerBlock);
constexpr index_t lane_base_offset =
(mIter % mIters_per_warp) * WarpGemm::kM;
// Scale tensor offset along K
constexpr index_t src_reg_offset = reg_block_offset + kQScale;
constexpr uint32_t kTileRows = 4;
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
static_for<0, WarpGemm::kM, WarpGemm::kCMLane>{}([&](auto c_row) {
// Multiply by 4 because output is stored in tiles of 4
// x CNLane
constexpr uint32_t row_base =
((c_row / kTiledCMsPerWarp) * kTiledCMsPerWarp) +
((c_row % kTiledCMsPerWarp) / WarpGemm::kCMLane);
constexpr uint32_t reg_offset_for_row_data =
c_row / WarpGemm::kCMLane;
// Lane index to source scale from
uint32_t src_lane_idx = lane_base_offset + row_base +
(__lane_id() / WarpGemm::kN * kTileRows);
// Directly index into thread buffer corresponding to
// desired row coefficient
auto& scale_reg =
aq_block_tensor.get_thread_buffer()[src_reg_offset];
uint32_t scale_reg_dword;
if constexpr(std::is_same_v<AQDataType, float>)
{
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
// Pull scale data across lanes
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
src_lane_idx * 4, __builtin_bit_cast(int, scale_reg_dword));
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
c_block_tensor
.get_thread_buffer()[tbuf_offset + reg_offset_for_row_data] +=
(c_warp_tensor.get_thread_buffer()[reg_offset_for_row_data] *
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
});
}
}
});
});

View File

@@ -50,7 +50,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
Problem::TransposeC>;
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if constexpr(PreshuffleQuant)
@@ -70,16 +70,30 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
}
else
{
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
if constexpr(Problem::TransposeC)
{
using TileEncodingPatternTransposeC =
TileDistributionEncodingPatternAQTransposedC<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
VecLoadSize>;
return TileEncodingPatternTransposeC::Make2DStaticTileDistribution();
}
else
{
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
return TileEncodingPattern::Make2DStaticTileDistribution();
return TileEncodingPattern::Make2DStaticTileDistribution();
}
}
}
@@ -98,7 +112,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
Problem::TransposeC>;
static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
std::is_same_v<typename Problem::ComputeDataType, bf8_t>);
static_assert(std::is_same_v<typename Problem::CDataType, float>);

View File

@@ -18,6 +18,7 @@ template <typename ADataType_,
typename BlockGemmShape_,
typename Traits_,
uint32_t QuantGroupSize_,
bool TransposeC_,
typename ComputeDataType_ = BDataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
@@ -50,7 +51,7 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
using typename Base::BLayout;
using typename Base::CLayout;
static constexpr bool TransposeC = false;
static constexpr bool TransposeC = TransposeC_;
using Base::kBlockSize;
@@ -102,6 +103,7 @@ template <typename ADataType_,
typename BlockGemmShape_,
typename Traits_,
uint32_t QuantGroupSize_,
bool TransposeC_,
typename ComputeDataType_ = BDataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
@@ -113,6 +115,7 @@ using GemmAQuantPipelineProblem = GemmAQuantPipelineProblemBase<ADataType_,
BlockGemmShape_,
Traits_,
QuantGroupSize_,
TransposeC_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,

View File

@@ -113,4 +113,55 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter
}
};
/// Tile-distribution encoding pattern for the AQ (A-side quantization scale)
/// tile used when the transposed-C path is enabled. The AQuant pipeline policy
/// selects this pattern instead of TileDistributionEncodingPatternAQ when
/// Problem::TransposeC is true: with transposed C, one thread only holds data
/// from one row, so no cross-lane exchange of AQ elements is required.
///
/// Template parameters:
///   BlockGemmShape - block-level GEMM shape (provides BlockWarps and kM)
///   WarpGemm       - warp-level GEMM descriptor (provides kM)
///   BlockSize      - number of threads per block
///   YPerTile       - tile extent along the Y (row / M) dimension
///   XPerTile       - tile extent along the X dimension
///   VecSize        - vector load width; must evenly divide XPerTile
template <typename BlockGemmShape,
typename WarpGemm,
index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize>
struct TileDistributionEncodingPatternAQTransposedC : public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
// Warp counts along the block's M / N / K dimensions.
static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
// Number of M iterations each warp performs to cover the block tile along M.
static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM);
// All warps in the block must be accounted for by the M x N x K warp grid.
static_assert(num_warps == MWarps * NWarps * KWarps);
// KWarps > 1 isn't supported
static_assert(KWarps == 1);
// # of elements per thread
static constexpr index_t X = XPerTile;
// NOTE(review): XR is passed alongside NWarps in the first (replication)
// sequence of the encoding below; presumably a replication factor along X —
// confirm why it is fixed at 2.
static constexpr index_t XR = 2;
// Number of iters per warp
// MIters are indexed using (Y0, Y1)
static constexpr index_t Y0 = MIterPerWarp;
// # of warps in Y dim
static constexpr index_t Y1 = MWarps;
// Lanes covering a warp tile's rows along Y.
static constexpr index_t Y2 = WarpGemm::kM;
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
/// Builds the static tile distribution for the AQ tile.
///
/// The encoding splits Y into <Y0 (M iters per warp), Y1 (M warps),
/// Y2 (WarpGemm::kM)> and keeps X whole as <X>; <NWarps, XR> is supplied as
/// the leading (replication) sequence. The remaining sequences map these
/// hidden dimensions onto the warp/lane partition and the per-thread Y/X
/// layout — see tile_distribution_encoding for the exact meaning of each
/// mapping sequence.
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps, XR>,
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
tuple<sequence<1, 0>, sequence<0, 1>>,
tuple<sequence<1, 0>, sequence<1, 2>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
};
} // namespace ck_tile