mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
Preshuffle AQ matrix in block scale gemm (#2624)
* Preshuffle AQ matrix in block scale gemm * turns the output to fp16. Increase the repetition time. --------- Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -156,6 +156,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
|
||||
|
||||
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
|
||||
};
|
||||
|
||||
public:
|
||||
@@ -322,6 +324,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
constexpr auto warp_size = get_warp_size();
|
||||
|
||||
// hot loop:
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
@@ -354,82 +357,153 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
}
|
||||
});
|
||||
|
||||
// Need to multiply aquant with accumulated C
|
||||
//
|
||||
// The accumulated C tile has the standard distribution. For example
|
||||
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
|
||||
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
|
||||
// [26,0], [27,0].
|
||||
//
|
||||
// These elements are in different rows, need to get the scale value
|
||||
// for the corresponding row.
|
||||
// Based on aquant's tile distribution, it can be inferred which
|
||||
// lane holds the relevant scale. For example, the scales corresponding
|
||||
// to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9,
|
||||
// 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 respectively.
|
||||
//
|
||||
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
|
||||
if constexpr(Traits::Preshuffle)
|
||||
{
|
||||
// A view is created on top of the preshuffled AQ, where each row of the
|
||||
// view is composed of a row from a warp tile within an AQ block tile.
|
||||
// Multiple warp tile rows that belong to the same block tile are laid
|
||||
// out as consecutive rows.
|
||||
//
|
||||
// When we need to multiply a C warp tile with an AQ warp tile, thread 0
|
||||
// in the warp will load AQ_warp_tile[0], thread 1 will load
|
||||
// AQ_warp_tile[1], and so on, up to thread 63, which will load
|
||||
// AQ_warp_tile[63]. The VGPR file in the warp acts similarly to LDS in
|
||||
// this context, but we use cross-lane operations to access the data.
|
||||
// (Cross-lane operations are faster than using LDS.)
|
||||
//
|
||||
// Note that when the size of the AQ warp tile is smaller than the warp
|
||||
// size, you need to pad the rows in the view to ensure that each thread
|
||||
// can read one element.
|
||||
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
constexpr uint32_t kTileRowsOfCPerThread = 4;
|
||||
|
||||
// MIters per warp
|
||||
constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM;
|
||||
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
|
||||
[&](auto c_row) {
|
||||
// For a warp tile of [16x16x32], take thread 0 as an example.
|
||||
// Its VGPR[0] stores the value from C_tile[0,0], VGPR[1] stores
|
||||
// C_tile[1,0], VGPR[2] stores C_tile[2,0], and VGPR[3] stores
|
||||
// C_tile[3,0]. This means VGPR[0] should be multiplied by
|
||||
// AQ_tile[0, 0], VGPR[1] by AQ_tile[1, 0], VGPR[2] by
|
||||
// AQ_tile[2, 0], and VGPR[3] by AQ_tile[3, 0].
|
||||
|
||||
// Reg block offset based on mIter
|
||||
constexpr index_t reg_block_offset =
|
||||
((mIter / mIters_per_warp) * Traits::AQPerBlock);
|
||||
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, 0]
|
||||
// from thread 1, ..., and AQ_tile[3, 0] from thread 3.
|
||||
auto pull_from_lane =
|
||||
((threadIdx.x & (warp_size - 1)) / Traits::WarpGemm::kN *
|
||||
kTileRowsOfCPerThread +
|
||||
c_row) *
|
||||
Traits::QScalesPerBlockRow +
|
||||
kQScale;
|
||||
auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
|
||||
|
||||
constexpr index_t lane_base_offset =
|
||||
(mIter % mIters_per_warp) * WarpGemm::kM;
|
||||
// cross lane ops
|
||||
uint32_t scale_reg_dword;
|
||||
|
||||
// Scale tensor offset along K
|
||||
constexpr index_t src_reg_offset = reg_block_offset + kQScale;
|
||||
if constexpr(std::is_same_v<AQDataType, float>)
|
||||
{
|
||||
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
|
||||
}
|
||||
else
|
||||
{
|
||||
scale_reg_dword = static_cast<uint32_t>(scale_reg);
|
||||
}
|
||||
|
||||
constexpr uint32_t kTileRows = 4;
|
||||
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
|
||||
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
|
||||
pull_from_lane << 2,
|
||||
__builtin_bit_cast(int, scale_reg_dword));
|
||||
|
||||
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
|
||||
|
||||
static_for<0, WarpGemm::kM, WarpGemm::kCMLane>{}([&](auto c_row) {
|
||||
// Multiply by 4 because output is stored in tiles of 4
|
||||
// x CNLane
|
||||
constexpr uint32_t row_base =
|
||||
((c_row / kTiledCMsPerWarp) * kTiledCMsPerWarp) +
|
||||
((c_row % kTiledCMsPerWarp) / WarpGemm::kCMLane);
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f *
|
||||
kA_cvt_scale * kB_cvt_scale);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// Need to multiply aquant with accumulated C
|
||||
//
|
||||
// The accumulated C tile has the standard distribution. For example
|
||||
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
|
||||
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
|
||||
// [26,0], [27,0].
|
||||
//
|
||||
// These elements are in different rows, need to get the scale value
|
||||
// for the corresponding row.
|
||||
// Based on aquant's tile distribution, it can be inferred which
|
||||
// lane holds the relevant scale. For example, the scales corresponding
|
||||
// to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9,
|
||||
// 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 respectively.
|
||||
//
|
||||
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
|
||||
|
||||
constexpr uint32_t reg_offset_for_row_data = c_row / WarpGemm::kCMLane;
|
||||
// MIters per warp
|
||||
constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM;
|
||||
|
||||
// Lane index to source scale from
|
||||
uint32_t src_lane_idx = lane_base_offset + row_base +
|
||||
(__lane_id() / WarpGemm::kN * kTileRows);
|
||||
// Reg block offset based on mIter
|
||||
constexpr index_t reg_block_offset =
|
||||
((mIter / mIters_per_warp) * Traits::AQPerBlock);
|
||||
|
||||
// Directly index into thread buffer corresponding to
|
||||
// desired row coefficient
|
||||
auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset];
|
||||
uint32_t scale_reg_dword;
|
||||
constexpr index_t lane_base_offset =
|
||||
(mIter % mIters_per_warp) * WarpGemm::kM;
|
||||
|
||||
if constexpr(std::is_same_v<AQDataType, float>)
|
||||
{
|
||||
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
|
||||
}
|
||||
else
|
||||
{
|
||||
scale_reg_dword = static_cast<uint32_t>(scale_reg);
|
||||
}
|
||||
// Scale tensor offset along K
|
||||
constexpr index_t src_reg_offset = reg_block_offset + kQScale;
|
||||
|
||||
// Pull scale data across lanes
|
||||
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
|
||||
src_lane_idx * 4, __builtin_bit_cast(int, scale_reg_dword));
|
||||
constexpr uint32_t kTileRows = 4;
|
||||
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
|
||||
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
|
||||
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
c_block_tensor
|
||||
.get_thread_buffer()[tbuf_offset + reg_offset_for_row_data] +=
|
||||
(c_warp_tensor.get_thread_buffer()[reg_offset_for_row_data] *
|
||||
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
|
||||
});
|
||||
static_for<0, WarpGemm::kM, WarpGemm::kCMLane>{}([&](auto c_row) {
|
||||
// Multiply by 4 because output is stored in tiles of 4
|
||||
// x CNLane
|
||||
constexpr uint32_t row_base =
|
||||
((c_row / kTiledCMsPerWarp) * kTiledCMsPerWarp) +
|
||||
((c_row % kTiledCMsPerWarp) / WarpGemm::kCMLane);
|
||||
|
||||
constexpr uint32_t reg_offset_for_row_data =
|
||||
c_row / WarpGemm::kCMLane;
|
||||
|
||||
// Lane index to source scale from
|
||||
uint32_t src_lane_idx = lane_base_offset + row_base +
|
||||
(__lane_id() / WarpGemm::kN * kTileRows);
|
||||
|
||||
// Directly index into thread buffer corresponding to
|
||||
// desired row coefficient
|
||||
auto& scale_reg =
|
||||
aq_block_tensor.get_thread_buffer()[src_reg_offset];
|
||||
uint32_t scale_reg_dword;
|
||||
|
||||
if constexpr(std::is_same_v<AQDataType, float>)
|
||||
{
|
||||
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
|
||||
}
|
||||
else
|
||||
{
|
||||
scale_reg_dword = static_cast<uint32_t>(scale_reg);
|
||||
}
|
||||
|
||||
// Pull scale data across lanes
|
||||
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
|
||||
src_lane_idx * 4, __builtin_bit_cast(int, scale_reg_dword));
|
||||
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
|
||||
|
||||
c_block_tensor
|
||||
.get_thread_buffer()[tbuf_offset + reg_offset_for_row_data] +=
|
||||
(c_warp_tensor.get_thread_buffer()[reg_offset_for_row_data] *
|
||||
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -3,11 +3,14 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -104,6 +107,7 @@ struct AQuantGemmKernel
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool Preshuffle = GemmPipeline::Preshuffle;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename GemmPipeline::AQDataType>;
|
||||
@@ -157,7 +161,7 @@ struct AQuantGemmKernel
|
||||
__device__ SplitKBatchOffset(const AQuantGemmKernelArgs& kargs,
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(I2);
|
||||
const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
|
||||
const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
|
||||
|
||||
@@ -372,14 +376,75 @@ struct AQuantGemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const auto get_padding_size = [](index_t length, index_t alignment) {
|
||||
return ck_tile::integer_least_multiple(length, alignment) - length;
|
||||
};
|
||||
|
||||
const auto& make_preshuffled_aq_tensor_view = [&]() {
|
||||
const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
|
||||
const auto aq_y = kargs.QK / GemmPipeline::KPerBlockAQ;
|
||||
|
||||
const auto aq_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
|
||||
make_tuple(aq_x, 1),
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
|
||||
const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
|
||||
const auto aq_pad0_desc = transform_tensor_descriptor(
|
||||
aq_desc,
|
||||
make_tuple(make_pass_through_transform(aq_y),
|
||||
make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
|
||||
const auto wave_tile_size =
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
|
||||
const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
|
||||
const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
|
||||
aq_pad0_desc,
|
||||
make_tuple(make_pass_through_transform(aq_y),
|
||||
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
const auto aq_pad1_desc = transform_tensor_descriptor(
|
||||
aq_unmerge_pad0_desc,
|
||||
make_tuple(make_pass_through_transform(aq_y),
|
||||
make_pass_through_transform(wave_tile_count_x),
|
||||
make_right_pad_transform(
|
||||
wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
const auto pad_wave_size =
|
||||
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
|
||||
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
|
||||
aq_pad1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(wave_tile_count_x, aq_y)),
|
||||
make_pass_through_transform(pad_wave_size)),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
|
||||
};
|
||||
|
||||
const auto& aq_tensor_view = [&]() {
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.M, kargs.QK),
|
||||
make_tuple(kargs.stride_AQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
if constexpr(Preshuffle)
|
||||
{
|
||||
return make_preshuffled_aq_tensor_view();
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.M, kargs.QK),
|
||||
make_tuple(kargs.stride_AQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
@@ -491,16 +556,7 @@ struct AQuantGemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& aq_pad_view = [&]() {
|
||||
const auto& aq_tensor_view = views.at(I1);
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
return pad_tensor_view(
|
||||
aq_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
|
||||
// TODO: Add support for padding.
|
||||
sequence<false, false>{});
|
||||
}();
|
||||
const auto& aq_pad_view = [&]() { return views.at(I1); }();
|
||||
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I2);
|
||||
@@ -543,8 +599,10 @@ struct AQuantGemmKernel
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
|
||||
const AQuantGemmKernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& aq_pad_view = views.at(I1);
|
||||
@@ -570,11 +628,26 @@ struct AQuantGemmKernel
|
||||
|
||||
const auto& aq_block_window = [&]() {
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
return make_tile_window(
|
||||
aq_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
|
||||
{i_m, 0});
|
||||
if constexpr(Preshuffle)
|
||||
{
|
||||
constexpr auto tile_window_width = get_warp_size();
|
||||
constexpr auto tile_window_height =
|
||||
TilePartitioner::MPerBlock / TilePartitioner::BlockGemmShape::WarpTile::at(I0);
|
||||
auto block_m_idx = i_m / TilePartitioner::MPerBlock;
|
||||
return make_tile_window(
|
||||
aq_pad_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_m_idx * kargs.K / TilePartitioner::BlockGemmShape::BlockTile::at(I2),
|
||||
0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
aq_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
@@ -633,7 +706,8 @@ struct AQuantGemmKernel
|
||||
a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
@@ -38,12 +38,9 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
|
||||
using YPerTile = number<MPerBlock>;
|
||||
using XPerTile = number<KPerBlockAQ>;
|
||||
|
||||
auto aq_copy_dram_window =
|
||||
make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile(), XPerTile()),
|
||||
aq_dram_block_window_tmp.get_window_lengths(),
|
||||
aq_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeAQDramTileDistribution<Problem>());
|
||||
return aq_copy_dram_window;
|
||||
|
||||
@@ -42,6 +42,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
|
||||
constexpr bool Preshuffle = Problem::Traits::Preshuffle;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
@@ -52,14 +53,34 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
false>;
|
||||
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize>;
|
||||
if constexpr(Preshuffle)
|
||||
{
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPatternAQ<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock / WarpGemm::kM,
|
||||
ck_tile::integer_least_multiple(
|
||||
WarpGemm::kM * KPerBlockAQ, get_warp_size()),
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
Preshuffle>;
|
||||
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
else
|
||||
{
|
||||
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlockAQ,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
Preshuffle>;
|
||||
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
@@ -134,6 +133,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
@@ -254,9 +254,6 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)");
|
||||
static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
|
||||
"Aq block window has incorrect lengths for defined AqLayout!");
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
@@ -312,8 +309,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
|
||||
// only row_major for AQ
|
||||
constexpr AQDramTileWindowStep aq_dram_tile_window_step =
|
||||
is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ);
|
||||
Preshuffle ? make_array(MPerBlock / BlockGemm::WarpGemm::kM, 0)
|
||||
: make_array(0, KPerBlockAQ);
|
||||
|
||||
// DRAM prefetch (global read 0)
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
@@ -50,10 +50,11 @@ template <typename BlockGemmShape,
|
||||
index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize>
|
||||
index_t KPerBlockAQ,
|
||||
index_t VecSize,
|
||||
bool Preshuffle>
|
||||
struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern
|
||||
{
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
@@ -69,26 +70,46 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter
|
||||
// KWarps > 1 isn't supported
|
||||
static_assert(KWarps == 1);
|
||||
|
||||
// # of elements per thread
|
||||
static constexpr index_t X = XPerTile;
|
||||
|
||||
static constexpr index_t Y0 = 1;
|
||||
static constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
|
||||
static constexpr index_t Y2 = MWarps;
|
||||
static constexpr index_t Y3 = WarpGemm::kM;
|
||||
static_assert(Y3 >= WarpGemm::kM, "Scales for all rows must be available within the warp.");
|
||||
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
|
||||
"Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWarps>,
|
||||
tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X>>,
|
||||
tuple<sequence<1, 0>, sequence<1, 1>>,
|
||||
tuple<sequence<2, 0>, sequence<0, 3>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 0>>{});
|
||||
if constexpr(Preshuffle)
|
||||
{
|
||||
// # of elements per thread
|
||||
constexpr index_t X2 = KPerBlockAQ;
|
||||
constexpr index_t X1 = warp_size / X2;
|
||||
constexpr index_t X0 = XPerTile / warp_size;
|
||||
|
||||
constexpr index_t Y1 = MWarps;
|
||||
constexpr index_t Y0 = YPerTile / Y1;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWarps>,
|
||||
tuple<sequence<Y0, Y1>, sequence<X0, X1, X2>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 2>>,
|
||||
tuple<sequence<1, 0>, sequence<1, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// # of elements per thread
|
||||
constexpr index_t X = XPerTile;
|
||||
|
||||
constexpr index_t Y0 = 1;
|
||||
constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
|
||||
constexpr index_t Y2 = MWarps;
|
||||
constexpr index_t Y3 = WarpGemm::kM;
|
||||
static_assert(Y3 >= WarpGemm::kM,
|
||||
"Scales for all rows must be available within the warp.");
|
||||
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
|
||||
"Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWarps>,
|
||||
tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X>>,
|
||||
tuple<sequence<1, 0>, sequence<1, 1>>,
|
||||
tuple<sequence<2, 0>, sequence<0, 3>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 0>>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ namespace ck_tile {
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool Preshuffle_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
@@ -29,6 +30,7 @@ struct TileGemmAQuantTraits
|
||||
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = Preshuffle_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user