Fix quant scale matrix layout for block scale gemm (#3079)

* Adding support for TiledPermuteN

* Adding test

* moving shuffle functions to common place

* resolving commit hook

* fix formatting
This commit is contained in:
Khushbu Agarwal
2025-10-27 13:56:07 -07:00
committed by GitHub
parent a46b725992
commit b11f53a484
10 changed files with 35 additions and 31 deletions

View File

@@ -5,7 +5,7 @@
#include <random>
#include <stdexcept>
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/shuffle_utils.hpp"
#include "ck_tile/host/tensor_shuffle_utils.hpp"
template <typename GemmConfig,
typename TypeConfig,

View File

@@ -46,7 +46,7 @@
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/reference/reference_transpose.hpp"
#include "ck_tile/host/rotating_buffers.hpp"
#include "ck_tile/host/shuffle_utils.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/host/tensor_shuffle_utils.hpp"
#include "ck_tile/host/timer.hpp"

View File

@@ -307,6 +307,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as corresponding "
"C block tensor data type!");
constexpr auto warp_size = get_warp_size();
// hot loop:
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
@@ -365,10 +366,11 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
});
});
});

View File

@@ -686,8 +686,8 @@ struct QuantGemmKernel
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(kargs.N, kargs.QK_B),
make_tuple(kargs.stride_BQ, 1),
make_tuple(kargs.QK_B, kargs.N),
make_tuple(1, kargs.stride_BQ),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
@@ -905,9 +905,9 @@ struct QuantGemmKernel
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
{i_n, 0});
make_tuple(number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
}
else
{

View File

@@ -52,8 +52,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
using TileEncodingPattern = tile_distribution_encoding_pattern_bq<BlockGemmShape,
WarpGemm,
BlockSize,
NPerBlock,
KPerBlockBQ,
NPerBlock,
VecLoadSize>;
return TileEncodingPattern::make_2d_static_tile_distribution();

View File

@@ -254,8 +254,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
static_assert(NPerBlock == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
static_assert(KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
"Bq block window has incorrect lengths for defined BqLayout!");
static_assert(is_a_col_major
@@ -313,7 +313,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BQDramTileWindowStep bq_dram_tile_window_step =
is_bq_col_major ? make_array(0, KPerBlockBQ) : make_array(KPerBlockBQ, 0);
is_bq_col_major ? make_array(KPerBlockBQ, 0) : make_array(0, KPerBlockBQ);
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
@@ -358,6 +358,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
if constexpr(HasHotLoop)
{
constexpr index_t tail_count =
((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) ? 1 : 2;
index_t i = 0;
do
{
@@ -403,7 +405,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 1));
} while(i < (num_loop - tail_count));
}
// tail
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))

View File

@@ -191,28 +191,28 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
static_assert(KWarps == 1);
// # of elements per thread
static constexpr index_t X = XPerTile;
static constexpr index_t XR = 2;
static constexpr index_t Y = YPerTile;
static constexpr index_t YR = 1;
// Number of iters per warp
// MIters are indexed using (Y0, Y1)
static constexpr index_t Y0 = NIterPerWarp;
static constexpr index_t X0 = NIterPerWarp;
// # of warps in Y dim
static constexpr index_t Y1 = NWarps;
static constexpr index_t X1 = NWarps;
static constexpr index_t Y2 = WarpGemm::kN;
static constexpr index_t X2 = WarpGemm::kN;
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along X.");
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, XR>,
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tile_distribution_encoding<sequence<MWarps, YR>,
tuple<sequence<Y>, sequence<X0, X1, X2>>,
tuple<sequence<0, 2>, sequence<0, 2>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 2>,
sequence<2, 1>,
sequence<0, 0>>{});
}
};

View File

@@ -236,7 +236,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
// BQ DRAM window for load
auto bq_copy_dram_window =
make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<KPerBlockBQ>{}),
make_tuple(number<KPerBlockBQ>{}, number<kNPerBlock>{}),
bq_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeBQDramTileDistribution<Problem>());
@@ -269,7 +269,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
BQBlockTile bq_block_tile, bq_block_tile_2;
bq_block_tile = load_tile(bq_copy_dram_window);
// move BQ to tile 1
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
// Prefill A0
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
@@ -318,7 +318,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
bq_block_tile_2 = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
// Prefill A(2i+1)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
@@ -360,7 +360,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
// Prefill A(2i+2)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);

View File

@@ -5,7 +5,7 @@
#include "test_gemm_quant_base.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/shuffle_utils.hpp"
#include "ck_tile/host/tensor_shuffle_utils.hpp"
struct GemmConfigBase
{