mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
[CK_Tile] Support for preshuffle weight(B) quant tensor for block scale gemm (#3165)
* formatted * formatted * formatting * formatting * formatting * [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * enable prefill shapes * [CK TILE GEMM] Refactor block_scale_gemm examples - Add support for rowcol and tensor GEMM operations * [CK TILE GEMM] Refactor block_scale_gemm examples - Update README * adding preshuffle quant as new parameter and its associated new files * remove debugging statements * adding test * enable preshuffle quant with permuteN * updating readme and correcponding gemmconfigs * updating cmake file * fixing CI failures for grouped quant gemm * addressing review comments * fixing CI issue * addressing reveiw comments * formatting * formatting * fixing aquant operator overlaoding * formatting --------- Co-authored-by: Cong Ma <congma13@amd.com> Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -271,6 +271,94 @@ struct QuantGemmKernel
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
private:
|
||||
CK_TILE_DEVICE static constexpr index_t get_padding_size(index_t length, index_t alignment)
|
||||
{
|
||||
return ck_tile::integer_least_multiple(length, alignment) - length;
|
||||
};
|
||||
// ===================================================================
|
||||
// Helper: Create Pre-shuffled Quantization Tensor Descriptor
|
||||
// ===================================================================
|
||||
template <index_t KPerBlockBQ,
|
||||
index_t NPerBlock,
|
||||
index_t WarpTileN,
|
||||
index_t GetVectorSizeBQ,
|
||||
typename BQDataType_>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QK_B)
|
||||
{
|
||||
// Step 1: Calculate base BQ tensor dimensions
|
||||
// ----------------------------------------------------------
|
||||
// bq_x: Number of quantization groups in N dimension
|
||||
// = N * KPerBlockBQ, where KPerBlockBQ is the number of
|
||||
// K-dimension groups per block
|
||||
// bq_y: Number of quantization groups in K dimension
|
||||
// = Total K groups (QK_B) / groups per block
|
||||
const auto bq_x = N * KPerBlockBQ;
|
||||
const auto bq_y = QK_B / KPerBlockBQ;
|
||||
|
||||
const auto bq_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(bq_y, bq_x), make_tuple(bq_x, 1), number<GetVectorSizeBQ>{}, number<1>{});
|
||||
|
||||
// Step 2: First padding transformation (block-level alignment)
|
||||
// ----------------------------------------------------------
|
||||
// Pad the X dimension to be a multiple of block_tile_size to ensure
|
||||
// each thread block can process complete tiles without edge cases
|
||||
const auto block_tile_size = NPerBlock * KPerBlockBQ;
|
||||
const auto bq_pad0_desc = transform_tensor_descriptor(
|
||||
bq_desc,
|
||||
make_tuple(make_pass_through_transform(bq_y),
|
||||
make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// Step 3: Unmerge transformation (wave-level decomposition)
|
||||
// ----------------------------------------------------------
|
||||
// Split the X dimension into [wave_tile_count_x, wave_tile_size]
|
||||
// This separates the work into tiles that can be processed by
|
||||
// individual warps/waves
|
||||
const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1];
|
||||
const auto wave_tile_size = WarpTileN * KPerBlockBQ;
|
||||
const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size);
|
||||
|
||||
const auto bq_unmerge_pad0_desc = transform_tensor_descriptor(
|
||||
bq_pad0_desc,
|
||||
make_tuple(make_pass_through_transform(bq_y),
|
||||
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
// Step 4: Second padding transformation (warp-level alignment)
|
||||
// ----------------------------------------------------------
|
||||
// Pad wave_tile_size to be a multiple of warp_size (typically 32 or 64)
|
||||
// This ensures coalesced memory accesses within each warp
|
||||
const auto bq_pad1_desc = transform_tensor_descriptor(
|
||||
bq_unmerge_pad0_desc,
|
||||
make_tuple(make_pass_through_transform(bq_y),
|
||||
make_pass_through_transform(wave_tile_count_x),
|
||||
make_right_pad_transform(wave_tile_size,
|
||||
get_padding_size(wave_tile_size, get_warp_size()))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
// Step 5: Final merge transformation (prepare for indexing)
|
||||
// ----------------------------------------------------------
|
||||
// Merge [bq_y, wave_tile_count_x] into a single outer dimension
|
||||
// This creates a 2D layout: [merged_outer_dim, pad_wave_size]
|
||||
// where merged_outer_dim = bq_y * wave_tile_count_x
|
||||
// This layout facilitates efficient block-to-data mapping
|
||||
const auto pad_wave_size = ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
|
||||
const auto bq_merge_pad1_desc = transform_tensor_descriptor(
|
||||
bq_pad1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)),
|
||||
make_pass_through_transform(pad_wave_size)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(bq_ptr, bq_merge_pad1_desc);
|
||||
}
|
||||
|
||||
public:
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
__device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
|
||||
@@ -509,17 +597,12 @@ struct QuantGemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const auto get_padding_size = [](index_t length, index_t alignment) {
|
||||
return ck_tile::integer_least_multiple(length, alignment) - length;
|
||||
};
|
||||
|
||||
const auto& aq_tensor_view = [&]() {
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
|
||||
const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
|
||||
|
||||
const auto aq_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
|
||||
make_tuple(aq_x, 1),
|
||||
@@ -540,6 +623,7 @@ struct QuantGemmKernel
|
||||
GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
|
||||
const auto wave_tile_count_x =
|
||||
ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
|
||||
|
||||
const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
|
||||
aq_pad0_desc,
|
||||
make_tuple(
|
||||
@@ -686,14 +770,27 @@ struct QuantGemmKernel
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
|
||||
make_tuple(1, kargs.stride_BQ),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
|
||||
return MakePreshuffledQuantTensorView<
|
||||
GemmPipeline::KPerBlockBQ,
|
||||
GemmPipeline::NPerBlock,
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(I1),
|
||||
GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
|
||||
make_tuple(1, kargs.stride_BQ),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -910,13 +1007,33 @@ struct QuantGemmKernel
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
|
||||
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
|
||||
{0, i_n / QuantGroupSize::kN});
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN;
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
constexpr auto tile_window_width =
|
||||
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
|
||||
constexpr auto tile_window_height = block_n / warp_n;
|
||||
auto block_n_idx = i_n / block_n;
|
||||
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_n_idx * tile_window_height, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
|
||||
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
|
||||
{0, i_n / QuantGroupSize::kN});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -979,14 +1096,24 @@ struct QuantGemmKernel
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(I1);
|
||||
index_t m = 0;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
m = kargs.M;
|
||||
}
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0);
|
||||
a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0, m);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I3);
|
||||
index_t n = 0;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
n = kargs.N;
|
||||
}
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0);
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant ||
|
||||
kQuantType == QuantType::TensorQuant)
|
||||
@@ -1074,12 +1201,18 @@ struct QuantGemmKernel
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I3);
|
||||
index_t n = 0;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
n = kargs.N;
|
||||
}
|
||||
return GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
smem_ptr_1,
|
||||
n);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1109,7 +1242,6 @@ struct QuantGemmKernel
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
|
||||
|
||||
Reference in New Issue
Block a user