mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[CK TILE] Fix bugs in AQuant preshuffle (#2700)
* [CK TILE] Fix bugs in AQuant preshuffle
- Make Aquant works with block Mx64x256. `M` could be 16, 32, 64
- Make Aquant works with warp 16x16x32 and 32x32x16.
* [CK TILE] Rename Preshuffle to PreshuffleQuant
The new name, PreshuffleQuant, explicitly states the function's purpose:
to preshuffle the quantization matrix.
* [CK TILE Block Scale] Use GemmConfig to save tile properties
- Remove specialization of GemmQuantTypeConfig
- Pass GemmConfig around which contains tile properties. Stop using hard
coded tile properties in `gemm_calc_aquant()`
* [CK TILE Block Scale] Rename GemmConfig used in block scale
- Remove unused GemmConfig
- Rename GemmConfig used in block scale
---------
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -99,15 +99,15 @@ struct AQuantGemmKernelArgs
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct AQuantGemmKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool Preshuffle = GemmPipeline::Preshuffle;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool PreshuffleQuant = GemmPipeline::PreshuffleQuant;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename GemmPipeline::AQDataType>;
|
||||
@@ -422,9 +422,9 @@ struct AQuantGemmKernel
|
||||
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
|
||||
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
|
||||
aq_pad1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(wave_tile_count_x, aq_y)),
|
||||
make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
|
||||
make_pass_through_transform(pad_wave_size)),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
|
||||
@@ -432,7 +432,7 @@ struct AQuantGemmKernel
|
||||
|
||||
const auto& aq_tensor_view = [&]() {
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
if constexpr(Preshuffle)
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
return make_preshuffled_aq_tensor_view();
|
||||
}
|
||||
@@ -599,10 +599,8 @@ struct AQuantGemmKernel
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
|
||||
const AQuantGemmKernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& aq_pad_view = views.at(I1);
|
||||
@@ -628,24 +626,27 @@ struct AQuantGemmKernel
|
||||
|
||||
const auto& aq_block_window = [&]() {
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
if constexpr(Preshuffle)
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto block_k = TilePartitioner::KPerBlock;
|
||||
constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr auto aqk_per_block =
|
||||
TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
constexpr auto tile_window_width = get_warp_size();
|
||||
constexpr auto tile_window_height =
|
||||
TilePartitioner::MPerBlock / TilePartitioner::BlockGemmShape::WarpTile::at(I0);
|
||||
auto block_m_idx = i_m / TilePartitioner::MPerBlock;
|
||||
constexpr auto tile_window_width =
|
||||
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
|
||||
constexpr auto tile_window_height = block_m / warp_m;
|
||||
auto block_m_idx = i_m / block_m;
|
||||
return make_tile_window(
|
||||
aq_pad_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_m_idx * kargs.K / TilePartitioner::BlockGemmShape::BlockTile::at(I2),
|
||||
0});
|
||||
{block_m_idx * tile_window_height, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
aq_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
|
||||
make_tuple(number<block_m>{}, number<block_k / GemmPipeline::QuantGroupSize>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
}();
|
||||
@@ -706,8 +707,7 @@ struct AQuantGemmKernel
|
||||
a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, kargs, block_idx_m, block_idx_n);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
@@ -718,7 +718,7 @@ struct AQuantGemmKernel
|
||||
const auto& b_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0);
|
||||
a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
Reference in New Issue
Block a user