mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Fix quant scale matrix layout for block scale gemm (#3079)
* Adding support for TiledPermuteN * Adding test * moving shuffle functions to common place * resolving commit hook * fix formatting
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
#include "ck_tile/host/shuffle_utils.hpp"
|
||||
#include "ck_tile/host/tensor_shuffle_utils.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
|
||||
@@ -46,7 +46,7 @@
|
||||
#include "ck_tile/host/reference/reference_topk.hpp"
|
||||
#include "ck_tile/host/reference/reference_transpose.hpp"
|
||||
#include "ck_tile/host/rotating_buffers.hpp"
|
||||
#include "ck_tile/host/shuffle_utils.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/host/tensor_shuffle_utils.hpp"
|
||||
#include "ck_tile/host/timer.hpp"
|
||||
|
||||
0
include/ck_tile/host/shuffle_utils.hpp → include/ck_tile/host/tensor_shuffle_utils.hpp
Normal file → Executable file
0
include/ck_tile/host/shuffle_utils.hpp → include/ck_tile/host/tensor_shuffle_utils.hpp
Normal file → Executable file
@@ -307,6 +307,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as corresponding "
|
||||
"C block tensor data type!");
|
||||
constexpr auto warp_size = get_warp_size();
|
||||
|
||||
// hot loop:
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
@@ -365,10 +366,11 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
|
||||
static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
|
||||
});
|
||||
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
|
||||
[&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -686,8 +686,8 @@ struct QuantGemmKernel
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(kargs.N, kargs.QK_B),
|
||||
make_tuple(kargs.stride_BQ, 1),
|
||||
make_tuple(kargs.QK_B, kargs.N),
|
||||
make_tuple(1, kargs.stride_BQ),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
@@ -905,9 +905,9 @@ struct QuantGemmKernel
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
|
||||
{i_n, 0});
|
||||
make_tuple(number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -52,8 +52,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_bq<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
NPerBlock,
|
||||
KPerBlockBQ,
|
||||
NPerBlock,
|
||||
VecLoadSize>;
|
||||
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
|
||||
@@ -254,8 +254,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
|
||||
static_assert(NPerBlock == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
|
||||
static_assert(KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
|
||||
"Bq block window has incorrect lengths for defined BqLayout!");
|
||||
|
||||
static_assert(is_a_col_major
|
||||
@@ -313,7 +313,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BQDramTileWindowStep bq_dram_tile_window_step =
|
||||
is_bq_col_major ? make_array(0, KPerBlockBQ) : make_array(KPerBlockBQ, 0);
|
||||
is_bq_col_major ? make_array(KPerBlockBQ, 0) : make_array(0, KPerBlockBQ);
|
||||
|
||||
// DRAM prefetch (global read 0)
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
@@ -358,6 +358,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
constexpr index_t tail_count =
|
||||
((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) ? 1 : 2;
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
@@ -403,7 +405,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 1));
|
||||
} while(i < (num_loop - tail_count));
|
||||
}
|
||||
// tail
|
||||
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
|
||||
|
||||
@@ -191,28 +191,28 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
static_assert(KWarps == 1);
|
||||
|
||||
// # of elements per thread
|
||||
static constexpr index_t X = XPerTile;
|
||||
static constexpr index_t XR = 2;
|
||||
static constexpr index_t Y = YPerTile;
|
||||
static constexpr index_t YR = 1;
|
||||
|
||||
// Number of iters per warp
|
||||
// MIters are indexed using (Y0, Y1)
|
||||
static constexpr index_t Y0 = NIterPerWarp;
|
||||
static constexpr index_t X0 = NIterPerWarp;
|
||||
|
||||
// # of warps in Y dim
|
||||
static constexpr index_t Y1 = NWarps;
|
||||
static constexpr index_t X1 = NWarps;
|
||||
|
||||
static constexpr index_t Y2 = WarpGemm::kN;
|
||||
static constexpr index_t X2 = WarpGemm::kN;
|
||||
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
|
||||
static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along Y.");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, XR>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1>>,
|
||||
tile_distribution_encoding<sequence<MWarps, YR>,
|
||||
tuple<sequence<Y>, sequence<X0, X1, X2>>,
|
||||
tuple<sequence<0, 2>, sequence<0, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -236,7 +236,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
// BQ DRAM window for load
|
||||
auto bq_copy_dram_window =
|
||||
make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<KPerBlockBQ>{}),
|
||||
make_tuple(number<KPerBlockBQ>{}, number<kNPerBlock>{}),
|
||||
bq_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeBQDramTileDistribution<Problem>());
|
||||
|
||||
@@ -269,7 +269,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
BQBlockTile bq_block_tile, bq_block_tile_2;
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
// move BQ to tile 1
|
||||
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
|
||||
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
|
||||
|
||||
// Prefill A0
|
||||
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
@@ -318,7 +318,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
bq_block_tile_2 = load_tile(bq_copy_dram_window);
|
||||
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
|
||||
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
|
||||
|
||||
// Prefill A(2i+1)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
@@ -360,7 +360,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
|
||||
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
|
||||
|
||||
// Prefill A(2i+2)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "test_gemm_quant_base.hpp"
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
#include "ck_tile/host/shuffle_utils.hpp"
|
||||
#include "ck_tile/host/tensor_shuffle_utils.hpp"
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user