[CK-Tile] Remove usage of tile partitioner's full gemm shape (#3204)

gemm shape should be used from the pipeline instead (where it gets from a problem description struct)
This commit is contained in:
Max Podkorytov
2025-11-18 09:56:40 -08:00
committed by GitHub
parent ac70206b2c
commit a3a4eb12bd
9 changed files with 31 additions and 36 deletions

View File

@@ -276,7 +276,7 @@ struct QuantGemmKernel
__device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(I2);
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
@@ -487,7 +487,7 @@ struct QuantGemmKernel
const SplitKBatchOffset& splitk_batch_offset)
{
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
@@ -537,7 +537,7 @@ struct QuantGemmKernel
const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
const auto wave_tile_size =
TilePartitioner::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
const auto wave_tile_count_x =
ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
@@ -597,7 +597,7 @@ struct QuantGemmKernel
const auto& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = splitk_batch_offset.splitted_k / K1;
@@ -627,7 +627,7 @@ struct QuantGemmKernel
}
else
{
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = splitk_batch_offset.splitted_k / K1;
@@ -649,10 +649,9 @@ struct QuantGemmKernel
{
if constexpr(PreshuffleB)
{
index_t kFlatK =
GemmPipeline::flatKPerWarp *
(splitk_batch_offset.splitted_k /
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatK = GemmPipeline::flatKPerWarp *
(splitk_batch_offset.splitted_k /
GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
return make_naive_tensor_view<address_space_enum::global>(
@@ -837,7 +836,7 @@ struct QuantGemmKernel
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
@@ -880,7 +879,7 @@ struct QuantGemmKernel
b_pad_view,
make_tuple(number<GemmPipeline::flatNPerWarp>{},
number<GemmPipeline::flatKPerWarp>{}),
{static_cast<int>(i_n / TilePartitioner::BlockGemmShape::WarpTile::at(I1)), 0});
{static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0});
}
else
{