[CK-Tile] Remove usage of tile partitioner's full gemm shape (#3204)

The GEMM shape should be taken from the pipeline instead (which obtains it from its problem description struct).
Max Podkorytov, 2025-11-18 09:56:40 -08:00, committed by GitHub
parent ac70206b2c
commit a3a4eb12bd
9 changed files with 31 additions and 36 deletions
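
For context, the change follows the pattern sketched below. This is a hypothetical sketch, not the actual ck_tile headers, and the names GemmPipelineProblem and ExampleKernel are illustrative only: the pipeline re-exports the BlockGemmShape it receives through its problem description struct, and kernels query the shape there rather than through the tile partitioner.

    // Hypothetical sketch of the intended dependency, not the real ck_tile code.
    template <typename BlockGemmShape_>
    struct GemmPipelineProblem
    {
        using BlockGemmShape = BlockGemmShape_; // the shape lives in the problem description
    };

    template <typename Problem>
    struct GemmPipeline
    {
        using BlockGemmShape = typename Problem::BlockGemmShape; // re-exported by the pipeline
    };

    template <typename TilePartitioner_, typename GemmPipeline_>
    struct ExampleKernel
    {
        // After this commit: take the shape from the pipeline, not from the partitioner.
        static constexpr int kKPerBlock = GemmPipeline_::BlockGemmShape::kK;
    };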


@@ -363,8 +363,8 @@ struct FlatmmKernel
template <class KernelArgs>
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
{
- constexpr auto N1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<1>{});
- constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
+ constexpr auto N1 = BlockGemmShape::WarpTile::at(number<1>{});
+ constexpr auto K1 = BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = kargs.k_batch * K1;
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
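
As a worked example with illustrative numbers (not taken from this commit): with K1 = 32 from the warp tile, kargs.k_batch = 4 and kargs.K = 1000, K_t = 4 * 32 = 128 and KRead = ((1000 + 128 - 1) / 128) * 32 = 8 * 32 = 256, so each split-K batch reads a K range of 256 elements, a multiple of K1, and the four batches together cover all 1000 elements of K.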


@@ -369,7 +369,7 @@ struct MoeFlatmmKernel
template <class KernelArgs>
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
{
- constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
+ constexpr auto K1 = BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = kargs.k_batch * K1;
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;


@@ -386,11 +386,9 @@ template <typename BlockGemmShapeType,
uint32_t TileSwizzleSubM = 8>
struct StreamKTilePartitioner
{
- using BlockGemmShape = BlockGemmShapeType;
- static constexpr uint32_t MPerBlock = BlockGemmShape::kM;
- static constexpr uint32_t NPerBlock = BlockGemmShape::kN;
- static constexpr uint32_t KPerBlock = BlockGemmShape::kK;
+ static constexpr uint32_t MPerBlock = BlockGemmShapeType::kM;
+ static constexpr uint32_t NPerBlock = BlockGemmShapeType::kN;
+ static constexpr uint32_t KPerBlock = BlockGemmShapeType::kK;
CK_TILE_HOST_DEVICE StreamKTilePartitioner() noexcept = delete;


@@ -22,11 +22,10 @@ template <typename BlockGemmShapeType,
StreamKReductionStrategy ReductionStrategyType = StreamKReductionStrategy::Atomic>
struct StreamKTilePartitionerBase
{
- using BlockGemmShape = BlockGemmShapeType;
- static constexpr index_t MPerBlock = BlockGemmShape::kM;
- static constexpr index_t NPerBlock = BlockGemmShape::kN;
- static constexpr index_t KPerBlock = BlockGemmShape::kK;
+ static constexpr index_t MPerBlock = BlockGemmShapeType::kM;
+ static constexpr index_t NPerBlock = BlockGemmShapeType::kN;
+ static constexpr index_t KPerBlock = BlockGemmShapeType::kK;
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType;
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);


@@ -325,7 +325,7 @@ struct UniversalGemmKernel
{
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
{
- constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
+ constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
@@ -584,7 +584,7 @@ struct UniversalGemmKernel
const KernelArgs& kargs,
const index_t k_size)
{
- static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
+ static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
const auto& as_tensor_view = generate_tuple(
[&](auto i) {
@@ -617,7 +617,7 @@ struct UniversalGemmKernel
using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
{
- if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
+ if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = k_size / K1;
@@ -649,7 +649,7 @@ struct UniversalGemmKernel
}
else
{
- if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
+ if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = k_size / K1;
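
As a rough illustration with assumed values (not from this commit): if GemmPipeline::GetSmemPackB() returns 8 and k_size is 256, then K1 = 8 and K0 = 256 / 8 = 32, so the permuted B data is presumably addressed as K0 = 32 groups of K1 = 8 contiguous K elements; the tensor-view construction that consumes K0 and K1 falls outside the lines shown in this hunk.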
@@ -675,8 +675,7 @@ struct UniversalGemmKernel
{
index_t kFlatK =
GemmPipeline::BlockGemmShape::flatKPerWarp *
- (k_size /
- TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
+ (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
return make_naive_tensor_view<address_space_enum::global>(
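
With illustrative numbers (not from this commit): if flatKPerWarp is 256, the warp-tile K extent is 32, k_size = kargs.K = 4096 and kargs.N = 2048, then kFlatK = 256 * (4096 / 32) = 32768 and kFlatN = 2048 * 4096 / 32768 = 256, i.e. the preshuffled B data is viewed as kFlatN rows of kFlatK elements so that kFlatN * kFlatK equals N * K.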


@@ -276,7 +276,7 @@ struct QuantGemmKernel
__device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
const std::size_t k_id = blockIdx.z)
{
- constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(I2);
+ constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
@@ -487,7 +487,7 @@ struct QuantGemmKernel
const SplitKBatchOffset& splitk_batch_offset)
{
- static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
+ static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
@@ -537,7 +537,7 @@ struct QuantGemmKernel
const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
const auto wave_tile_size =
- TilePartitioner::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
+ GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
const auto wave_tile_count_x =
ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
@@ -597,7 +597,7 @@ struct QuantGemmKernel
const auto& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
- if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
+ if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = splitk_batch_offset.splitted_k / K1;
@@ -627,7 +627,7 @@ struct QuantGemmKernel
}
else
{
- if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
+ if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = splitk_batch_offset.splitted_k / K1;
@@ -649,10 +649,9 @@ struct QuantGemmKernel
{
if constexpr(PreshuffleB)
{
- index_t kFlatK =
- GemmPipeline::flatKPerWarp *
- (splitk_batch_offset.splitted_k /
- TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
+ index_t kFlatK = GemmPipeline::flatKPerWarp *
+ (splitk_batch_offset.splitted_k /
+ GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
return make_naive_tensor_view<address_space_enum::global>(
@@ -837,7 +836,7 @@ struct QuantGemmKernel
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
- constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
+ constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
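
With illustrative values (not from this commit): if warp_m = 16, TilePartitioner::KPerBlock = 128 and QuantGroupSize::kK = 64, then aqk_per_block = 128 / 64 = 2 and warp_m * aqk_per_block = 32; assuming a 64-lane wavefront, integer_least_multiple(32, 64) rounds the tile_window_width up to 64.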
@@ -880,7 +879,7 @@ struct QuantGemmKernel
b_pad_view,
make_tuple(number<GemmPipeline::flatNPerWarp>{},
number<GemmPipeline::flatKPerWarp>{}),
- {static_cast<int>(i_n / TilePartitioner::BlockGemmShape::WarpTile::at(I1)), 0});
+ {static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0});
}
else
{


@@ -724,8 +724,8 @@ struct GroupedConvolutionBackwardDataKernel
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
const index_t group_id)
{
- static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
- static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
+ static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
+ static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
const auto& a_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(
a_ptr,


@@ -464,7 +464,7 @@ struct GroupedConvolutionBackwardWeightKernel
__device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const std::size_t k_id = blockIdx.z)
{
- constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
+ constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
const index_t KRead = amd_wave_read_first_lane((kargs.GemmK + K_t - 1) / K_t * K1);
@@ -646,8 +646,8 @@ struct GroupedConvolutionBackwardWeightKernel
WeiDataType* c_ptr,
const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
- static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
- static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
+ static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
+ static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
const auto& a_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(a_ptr,
kargs.a_grid_desc_k_m); // A: out


@@ -745,8 +745,8 @@ struct GroupedConvolutionForwardKernel
const BDescType& b_desc,
const CDescType& c_desc)
{
- static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
- static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
+ static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
+ static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
const auto& a_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
}();