[CK-Tile] Remove usage of tile partitioner's full gemm shape (#3204)

gemm shape should be used from the pipeline instead (where it gets from a problem description struct)
This commit is contained in:
Max Podkorytov
2025-11-18 09:56:40 -08:00
committed by GitHub
parent ac70206b2c
commit a3a4eb12bd
9 changed files with 31 additions and 36 deletions

View File

@@ -386,11 +386,9 @@ template <typename BlockGemmShapeType,
uint32_t TileSwizzleSubM = 8>
struct StreamKTilePartitioner
{
using BlockGemmShape = BlockGemmShapeType;
static constexpr uint32_t MPerBlock = BlockGemmShape::kM;
static constexpr uint32_t NPerBlock = BlockGemmShape::kN;
static constexpr uint32_t KPerBlock = BlockGemmShape::kK;
static constexpr uint32_t MPerBlock = BlockGemmShapeType::kM;
static constexpr uint32_t NPerBlock = BlockGemmShapeType::kN;
static constexpr uint32_t KPerBlock = BlockGemmShapeType::kK;
CK_TILE_HOST_DEVICE StreamKTilePartitioner() noexcept = delete;

View File

@@ -22,11 +22,10 @@ template <typename BlockGemmShapeType,
StreamKReductionStrategy ReductionStrategyType = StreamKReductionStrategy::Atomic>
struct StreamKTilePartitionerBase
{
using BlockGemmShape = BlockGemmShapeType;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t MPerBlock = BlockGemmShapeType::kM;
static constexpr index_t NPerBlock = BlockGemmShapeType::kN;
static constexpr index_t KPerBlock = BlockGemmShapeType::kK;
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType;
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);

View File

@@ -325,7 +325,7 @@ struct UniversalGemmKernel
{
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
@@ -584,7 +584,7 @@ struct UniversalGemmKernel
const KernelArgs& kargs,
const index_t k_size)
{
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
const auto& as_tensor_view = generate_tuple(
[&](auto i) {
@@ -617,7 +617,7 @@ struct UniversalGemmKernel
using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = k_size / K1;
@@ -649,7 +649,7 @@ struct UniversalGemmKernel
}
else
{
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = k_size / K1;
@@ -675,8 +675,7 @@ struct UniversalGemmKernel
{
index_t kFlatK =
GemmPipeline::BlockGemmShape::flatKPerWarp *
(k_size /
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
return make_naive_tensor_view<address_space_enum::global>(