Preshuffle AQ matrix in block scale gemm (#2624)

* Preshuffle the AQ matrix in the block scale GEMM

* Turn the output to fp16; increase the repetition count.

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
Cong Ma, 2025-08-12 22:32:51 -06:00 (committed by GitHub)
parent 0f42a92fc1, commit 452791a3ba
13 changed files with 667 additions and 228 deletions
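For context: "preshuffling" here reorders the per-group quantization scales of the A matrix (AQ) from their natural row-major M x QK layout (QK = K / QuantGroupSize) into a wave-tile-contiguous layout, padded to the warp size, so each wavefront can fetch its scales with a coalesced, vectorized read. Below is a minimal host-side sketch of that idea; the helper name and float scale type are assumptions, and the kernel's real layout (built by the descriptor chain in the diff) additionally splits the scale-group dimension into per-K-block slices, which this sketch collapses.

#include <cstdint>
#include <vector>

// Hypothetical reference preshuffle, illustration only (not the repo's API).
// Scales of each wave tile (M_warp rows x kq groups) become contiguous, and
// each tile is zero-padded up to a multiple of the warp size.
std::vector<float> preshuffle_aq_reference(const std::vector<float>& aq, // M x kq, row-major
                                           int M, int kq, int M_warp, int warp_size = 64)
{
    const int wave_tile = M_warp * kq;                             // scales per wave tile
    const int padded    = (wave_tile + warp_size - 1) / warp_size * warp_size;
    const int tiles     = (M + M_warp - 1) / M_warp;
    std::vector<float> out(static_cast<std::size_t>(tiles) * padded, 0.0f);
    for(int t = 0; t < tiles; ++t)
        for(int r = 0; r < M_warp; ++r)
            for(int q = 0; q < kq; ++q)
            {
                const int m = t * M_warp + r;
                if(m < M)
                    out[static_cast<std::size_t>(t) * padded + r * kq + q] =
                        aq[static_cast<std::size_t>(m) * kq + q];
            }
    return out;
}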


@@ -3,11 +3,14 @@
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
@@ -104,6 +107,7 @@ struct AQuantGemmKernel
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
+static constexpr bool Preshuffle = GemmPipeline::Preshuffle;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using AQDataType = remove_cvref_t<typename GemmPipeline::AQDataType>;
@@ -157,7 +161,7 @@ struct AQuantGemmKernel
__device__ SplitKBatchOffset(const AQuantGemmKernelArgs& kargs,
const std::size_t k_id = blockIdx.z)
{
-constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
+constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(I2);
const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
@@ -372,14 +376,75 @@ struct AQuantGemmKernel
}
}();
+const auto get_padding_size = [](index_t length, index_t alignment) {
+    return ck_tile::integer_least_multiple(length, alignment) - length;
+};
+const auto& make_preshuffled_aq_tensor_view = [&]() {
+    const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
+    const auto aq_y = kargs.QK / GemmPipeline::KPerBlockAQ;
+    const auto aq_desc =
+        make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
+                                     make_tuple(aq_x, 1),
+                                     number<GemmPipeline::GetVectorSizeAQ()>{},
+                                     number<1>{});
+    const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
+    const auto aq_pad0_desc = transform_tensor_descriptor(
+        aq_desc,
+        make_tuple(make_pass_through_transform(aq_y),
+                   make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
+        make_tuple(sequence<0>{}, sequence<1>{}),
+        make_tuple(sequence<0>{}, sequence<1>{}));
+    const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
+    const auto wave_tile_size =
+        TilePartitioner::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
+    const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
+    const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
+        aq_pad0_desc,
+        make_tuple(make_pass_through_transform(aq_y),
+                   make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
+        make_tuple(sequence<0>{}, sequence<1>{}),
+        make_tuple(sequence<0>{}, sequence<1, 2>{}));
+    const auto aq_pad1_desc = transform_tensor_descriptor(
+        aq_unmerge_pad0_desc,
+        make_tuple(make_pass_through_transform(aq_y),
+                   make_pass_through_transform(wave_tile_count_x),
+                   make_right_pad_transform(
+                       wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
+        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
+        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
+    const auto pad_wave_size =
+        ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
+    const auto aq_merge_pad1_desc = transform_tensor_descriptor(
+        aq_pad1_desc,
+        make_tuple(make_merge_transform(make_tuple(wave_tile_count_x, aq_y)),
+                   make_pass_through_transform(pad_wave_size)),
+        make_tuple(sequence<1, 0>{}, sequence<2>{}),
+        make_tuple(sequence<0>{}, sequence<1>{}));
+    return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
+};
const auto& aq_tensor_view = [&]() {
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
-    return make_naive_tensor_view<address_space_enum::global>(
-        aq_ptr,
-        make_tuple(kargs.M, kargs.QK),
-        make_tuple(kargs.stride_AQ, 1),
-        number<GemmPipeline::GetVectorSizeAQ()>{},
-        number<1>{});
+    if constexpr(Preshuffle)
+    {
+        return make_preshuffled_aq_tensor_view();
+    }
+    else
+    {
+        return make_naive_tensor_view<address_space_enum::global>(
+            aq_ptr,
+            make_tuple(kargs.M, kargs.QK),
+            make_tuple(kargs.stride_AQ, 1),
+            number<GemmPipeline::GetVectorSizeAQ()>{},
+            number<1>{});
+    }
}();
const auto& b_tensor_view = [&]() {
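The descriptor chain above composes four transforms: pad aq_x up to a block-tile multiple, unmerge it into (wave-tile index, position in tile), pad each wave tile up to the warp size, then merge (wave-tile index, aq_y) into a single row dimension. A plain-C++ reading of the resulting index map follows, as a sketch under my interpretation of ck_tile's merge/pad semantics (the authoritative mapping lives in coordinate_transform.hpp):

#include <cstdint>

// Sketch: flat memory offset equivalent to the pad -> unmerge -> pad -> merge
// chain for the preshuffled AQ view. Padded coordinates return -1 here; the
// actual kernel masks them via the tile window instead.
struct aq_preshuffled_indexer
{
    int aq_y;           // QK / KPerBlockAQ
    int aq_x;           // M * KPerBlockAQ
    int wave_tile_size; // WarpTile M extent * KPerBlockAQ

    // row in [0, wave_tile_count_x * aq_y), lane in [0, pad_wave_size)
    std::int64_t offset(int row, int lane) const
    {
        const int tile = row / aq_y; // wave-tile index along (padded) aq_x
        const int y    = row % aq_y; // K-block index
        if(lane >= wave_tile_size)   // warp-size padding lane
            return -1;
        const int x = tile * wave_tile_size + lane;
        if(x >= aq_x)                // block-tile padding
            return -1;
        return static_cast<std::int64_t>(y) * aq_x + x; // base strides are (aq_x, 1)
    }
};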
@@ -491,16 +556,7 @@ struct AQuantGemmKernel
}
}();
-const auto& aq_pad_view = [&]() {
-    const auto& aq_tensor_view = views.at(I1);
-    static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
-    return pad_tensor_view(
-        aq_tensor_view,
-        make_tuple(number<TilePartitioner::MPerBlock>{},
-                   number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
-        // TODO: Add support for padding.
-        sequence<false, false>{});
-}();
+const auto& aq_pad_view = [&]() { return views.at(I1); }();
const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I2);
@@ -543,8 +599,10 @@ struct AQuantGemmKernel
}
template <typename PadView>
-CK_TILE_DEVICE static auto
-MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
+CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
+                                               const AQuantGemmKernelArgs& kargs,
+                                               const index_t i_m,
+                                               const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& aq_pad_view = views.at(I1);
@@ -570,11 +628,26 @@ struct AQuantGemmKernel
const auto& aq_block_window = [&]() {
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
-    return make_tile_window(
-        aq_pad_view,
-        make_tuple(number<TilePartitioner::MPerBlock>{},
-                   number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
-        {i_m, 0});
+    if constexpr(Preshuffle)
+    {
+        constexpr auto tile_window_width = get_warp_size();
+        constexpr auto tile_window_height =
+            TilePartitioner::MPerBlock / TilePartitioner::BlockGemmShape::WarpTile::at(I0);
+        auto block_m_idx = i_m / TilePartitioner::MPerBlock;
+        return make_tile_window(
+            aq_pad_view,
+            make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
+            {block_m_idx * kargs.K / TilePartitioner::BlockGemmShape::BlockTile::at(I2),
+             0});
+    }
+    else
+    {
+        return make_tile_window(
+            aq_pad_view,
+            make_tuple(number<TilePartitioner::MPerBlock>{},
+                       number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
+            {i_m, 0});
+    }
}();
const auto& b_block_window = [&]() {
@@ -633,7 +706,8 @@ struct AQuantGemmKernel
a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
-auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+auto gemm_tile_windows =
+    MakeGemmTileWindows(gemm_pad_views, kargs, block_idx_m, block_idx_n);
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
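One arithmetic detail of the preshuffled window origin is easy to miss: BlockTile::at(I2) is the K extent of the block tile, so block_m_idx * kargs.K / KPerBlock spaces consecutive M blocks K / KPerBlock merged rows apart. Assuming KPerBlockAQ == KPerBlock / QuantGroupSize (which the window shapes above suggest), that spacing equals aq_y, the number of merged rows each wave tile spans. A small standalone check of that arithmetic with made-up example sizes:

#include <cassert>
#include <cstdio>

int main()
{
    // Example sizes, assumptions for illustration only (not values from the PR).
    const int K = 4096, KPerBlock = 128, QuantGroupSize = 32;
    const int MPerBlock = 128, WarpTileM = 16, warp_size = 64;

    const int QK          = K / QuantGroupSize;     // scale groups along K
    const int KPerBlockAQ = KPerBlock / QuantGroupSize;
    const int aq_y        = QK / KPerBlockAQ;       // merged rows per wave tile

    assert(aq_y == K / KPerBlock); // so {block_m_idx * kargs.K / KPerBlock, 0} steps one tile per M block

    const int height = MPerBlock / WarpTileM;       // tile_window_height in the diff
    for(int block_m_idx = 0; block_m_idx < 3; ++block_m_idx)
        std::printf("M block %d: %d x %d window at merged row %d\n",
                    block_m_idx, height, warp_size, block_m_idx * K / KPerBlock);
    return 0;
}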