Support transposed C tile in Aquant (#2679)

The performance of Aquant has increased after enabling transposed C.

Do not need to exchange AQ elements among lanes after enabling
transposed C as one thread only holds data from one row.
This commit is contained in:
Cong Ma
2025-08-28 14:28:09 -06:00
committed by GitHub
parent 0758883fa4
commit 428090f749
10 changed files with 276 additions and 154 deletions

View File

@@ -50,7 +50,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
Problem::TransposeC>;
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if constexpr(PreshuffleQuant)
@@ -70,16 +70,30 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
}
else
{
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
if constexpr(Problem::TransposeC)
{
using TileEncodingPatternTransposeC =
TileDistributionEncodingPatternAQTransposedC<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
VecLoadSize>;
return TileEncodingPatternTransposeC::Make2DStaticTileDistribution();
}
else
{
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
return TileEncodingPattern::Make2DStaticTileDistribution();
return TileEncodingPattern::Make2DStaticTileDistribution();
}
}
}
@@ -98,7 +112,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
Problem::TransposeC>;
static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
std::is_same_v<typename Problem::ComputeDataType, bf8_t>);
static_assert(std::is_same_v<typename Problem::CDataType, float>);

View File

@@ -18,6 +18,7 @@ template <typename ADataType_,
typename BlockGemmShape_,
typename Traits_,
uint32_t QuantGroupSize_,
bool TransposeC_,
typename ComputeDataType_ = BDataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
@@ -50,7 +51,7 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
using typename Base::BLayout;
using typename Base::CLayout;
static constexpr bool TransposeC = false;
static constexpr bool TransposeC = TransposeC_;
using Base::kBlockSize;
@@ -102,6 +103,7 @@ template <typename ADataType_,
typename BlockGemmShape_,
typename Traits_,
uint32_t QuantGroupSize_,
bool TransposeC_,
typename ComputeDataType_ = BDataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
@@ -113,6 +115,7 @@ using GemmAQuantPipelineProblem = GemmAQuantPipelineProblemBase<ADataType_,
BlockGemmShape_,
Traits_,
QuantGroupSize_,
TransposeC_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,

View File

@@ -113,4 +113,55 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter
}
};
template <typename BlockGemmShape,
typename WarpGemm,
index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize>
struct TileDistributionEncodingPatternAQTransposedC : public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM);
static_assert(num_warps == MWarps * NWarps * KWarps);
// KWarps > 1 isn't supported
static_assert(KWarps == 1);
// # of elements per thread
static constexpr index_t X = XPerTile;
static constexpr index_t XR = 2;
// Number of iters per warp
// MIters are indexed using (Y0, Y1)
static constexpr index_t Y0 = MIterPerWarp;
// # of warps in Y dim
static constexpr index_t Y1 = MWarps;
static constexpr index_t Y2 = WarpGemm::kM;
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps, XR>,
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
tuple<sequence<1, 0>, sequence<0, 1>>,
tuple<sequence<1, 0>, sequence<1, 2>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
};
} // namespace ck_tile