mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_Tile] Support for group size 128 for Preshuffle quant for 2d block scale gemm (#3462)
* formatted * formatted * formatting * formatting * formatting * [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * enable prefill shapes * [CK TILE GEMM] Refactor block_scale_gemm examples - Add support for rowcol and tensor GEMM operations * [CK TILE GEMM] Refactor block_scale_gemm examples - Update README * adding preshuffle quant as new parameter and its associated new files * remove debugging statements * adding test * enable preshuffle quant with permuteN * updating readme and correcponding gemmconfigs * updating cmake file * fixing CI failures for grouped quant gemm * debugging permuteN * debugging * debugging PermuteN * initial commit * resolving merge conflicts * adding test cases * initial commit with prints * debugging * fine-grained working * debugging medium grained * fixing the tile window * formatting * enabling prefill shapes * working prefill shapes * formatted * clean up * code cleanup * bug fix after merging with develop * G128 working for both prefill and decode shapes for preshufflequant * clean up after merging with develop * fixing group 64 for decode shapes * non preshufflequant working for group size 128 * enable preshuffleb and preshufflequant with variour group sizes * reduce build time by splitting example into diff datatype files * Adding tests for preshuffleQuant * address review comment * fix for gfx1201 * compile time fix for gfx1201 * clang formatted --------- Co-authored-by: Cong Ma <congma13@amd.com> Co-authored-by: Thomas Ning <Thomas.Ning@amd.com> Co-authored-by: Agarwal <khuagarw@ctr2-alola-login-03.amd.com>
This commit is contained in:
@@ -887,23 +887,27 @@ struct QuantGemmKernel
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr auto block_n =
|
||||
TilePartitioner::NPerBlock /
|
||||
QuantGroupSize::kN; // Number of N-dimension quantization groups per block
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(
|
||||
I1); // Number of N-dimension elements per warp
|
||||
constexpr auto warp_per_group =
|
||||
(QuantGroupSize::kN <
|
||||
warp_n) // Determine how many warps share the same scale in N-dimension
|
||||
? (warp_n / QuantGroupSize::kN)
|
||||
: (QuantGroupSize::kN / warp_n);
|
||||
constexpr auto bqk_per_block =
|
||||
TilePartitioner::KPerBlock /
|
||||
QuantGroupSize::kK; // Number of K-dimension quantization groups per block
|
||||
constexpr auto
|
||||
tile_window_width = // The pre-shuffled layout flattens warp_n ×
|
||||
// bqk_per_block scales per row, Padded up to warp_size
|
||||
// to ensure coalesced memory access.
|
||||
|
||||
// Number of N-dimension quantization groups per block
|
||||
constexpr auto block_n = (QuantGroupSize::kN <= TilePartitioner::NPerBlock)
|
||||
? TilePartitioner::NPerBlock / QuantGroupSize::kN
|
||||
: QuantGroupSize::kN / TilePartitioner::NPerBlock;
|
||||
|
||||
// Number of N-dimension elements per warp
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
// Determine how many warps share the same scale in N-dimension
|
||||
constexpr auto warp_per_group = (QuantGroupSize::kN < warp_n)
|
||||
? (warp_n / QuantGroupSize::kN)
|
||||
: (QuantGroupSize::kN / warp_n);
|
||||
|
||||
// Number of K-dimension quantization groups per block
|
||||
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
// The pre-shuffled layout flattens warp_n ×
|
||||
// bqk_per_block scales per row, Padded up to warp_size
|
||||
// to ensure coalesced memory access.
|
||||
constexpr auto tile_window_width =
|
||||
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
|
||||
|
||||
// Adapts based on fine vs coarse quantization granularity:
|
||||
@@ -916,23 +920,42 @@ struct QuantGemmKernel
|
||||
// height = block_n
|
||||
constexpr auto tile_window_height =
|
||||
(QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
|
||||
auto block_n_idx =
|
||||
i_n / TilePartitioner::NPerBlock; // Converts the global N-index (i_n) to a
|
||||
// block index.
|
||||
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_n_idx * tile_window_height, 0});
|
||||
auto block_n_idx = i_n / TilePartitioner::NPerBlock;
|
||||
|
||||
// For decode shapes GN: 128, Blocks needs to repeat 0,0,1,1,2,2 ...
|
||||
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
|
||||
{
|
||||
block_n_idx = block_n_idx >> 1;
|
||||
}
|
||||
|
||||
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
|
||||
{
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_n_idx, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_n_idx * tile_window_height, 0});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto tensor_dim =
|
||||
(QuantGroupSize::kN <= TilePartitioner::NPerBlock)
|
||||
? TilePartitioner::NPerBlock / QuantGroupSize::kN
|
||||
: 1;
|
||||
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
|
||||
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
|
||||
number<tensor_dim>{}),
|
||||
{0, i_n / QuantGroupSize::kN});
|
||||
}
|
||||
else
|
||||
@@ -940,7 +963,7 @@ struct QuantGemmKernel
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
|
||||
make_tuple(number<tensor_dim>{},
|
||||
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
|
||||
{i_n / QuantGroupSize::kN, 0});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user