Preshuffle AQ matrix in block scale gemm (#2624)

* Preshuffle AQ matrix in block scale gemm

* Turn the output to fp16 and increase the repetition count.

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
Authored by Cong Ma on 2025-08-12 22:32:51 -06:00; committed by GitHub
parent 0f42a92fc1
commit 452791a3ba
13 changed files with 667 additions and 228 deletions


@@ -156,6 +156,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
};
public:
@@ -322,6 +324,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
constexpr auto warp_size = get_warp_size();
// hot loop:
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
@@ -354,82 +357,153 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
}
});
// Need to multiply aquant with accumulated C
if constexpr(Traits::Preshuffle)
{
    // A view is created on top of the preshuffled AQ, where each row of the
    // view is composed of a row from a warp tile within an AQ block tile.
    // Multiple warp tile rows that belong to the same block tile are laid
    // out as consecutive rows.
    //
    // When we need to multiply a C warp tile with an AQ warp tile, thread 0
    // in the warp will load AQ_warp_tile[0], thread 1 will load
    // AQ_warp_tile[1], and so on, up to thread 63, which will load
    // AQ_warp_tile[63]. The VGPR file in the warp acts similarly to LDS in
    // this context, but we use cross-lane operations to access the data.
    // (Cross-lane operations are faster than using LDS.)
    //
    // Note that when the size of the AQ warp tile is smaller than the warp
    // size, the rows in the view must be padded so that each thread can
    // read one element.
    constexpr auto tbuf_offset =
        number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
                   merge_sequences(sequence<mIter, nIter>{},
                                   c_warp_y_index_zeros)) /
               CBlockTensor::PackedSize>{};
    constexpr uint32_t kTileRowsOfCPerThread = 4;
    static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
        [&](auto c_row) {
            // For a warp tile of [16x16x32], take thread 0 as an example.
            // Its VGPR[0] stores the value from C_tile[0,0], VGPR[1] stores
            // C_tile[1,0], VGPR[2] stores C_tile[2,0], and VGPR[3] stores
            // C_tile[3,0]. This means VGPR[0] should be multiplied by
            // AQ_tile[0, 0], VGPR[1] by AQ_tile[1, 0], VGPR[2] by
            // AQ_tile[2, 0], and VGPR[3] by AQ_tile[3, 0].
            //
            // Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, 0]
            // from thread 1, ..., and AQ_tile[3, 0] from thread 3.
            auto pull_from_lane =
                ((threadIdx.x & (warp_size - 1)) / Traits::WarpGemm::kN *
                     kTileRowsOfCPerThread +
                 c_row) *
                    Traits::QScalesPerBlockRow +
                kQScale;
            auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
            // cross lane ops
            uint32_t scale_reg_dword;
            if constexpr(std::is_same_v<AQDataType, float>)
            {
                scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
            }
            else
            {
                scale_reg_dword = static_cast<uint32_t>(scale_reg);
            }
            int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
                pull_from_lane << 2,
                __builtin_bit_cast(int, scale_reg_dword));
            float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
            c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
                (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f *
                 kA_cvt_scale * kB_cvt_scale);
        });
}
else
{
    // Need to multiply aquant with accumulated C
    //
    // The accumulated C tile has the standard distribution. For example,
    // lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
    // [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
    // [26,0], [27,0].
    //
    // These elements are in different rows, so we need the scale value
    // for each corresponding row.
    // Based on aquant's tile distribution, it can be inferred which
    // lane holds the relevant scale. For example, the scales corresponding
    // to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9,
    // 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, respectively.
    //
    // These scales can be obtained using __builtin_amdgcn_ds_bpermute.
    // MIters per warp
    constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM;
    // Reg block offset based on mIter
    constexpr index_t reg_block_offset =
        ((mIter / mIters_per_warp) * Traits::AQPerBlock);
    constexpr index_t lane_base_offset =
        (mIter % mIters_per_warp) * WarpGemm::kM;
    // Scale tensor offset along K
    constexpr index_t src_reg_offset = reg_block_offset + kQScale;
    constexpr uint32_t kTileRows = 4;
    constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
    constexpr auto tbuf_offset =
        number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
                   merge_sequences(sequence<mIter, nIter>{},
                                   c_warp_y_index_zeros)) /
               CBlockTensor::PackedSize>{};
    static_for<0, WarpGemm::kM, WarpGemm::kCMLane>{}([&](auto c_row) {
        // Multiply by 4 because output is stored in tiles of 4
        // x CNLane
        constexpr uint32_t row_base =
            ((c_row / kTiledCMsPerWarp) * kTiledCMsPerWarp) +
            ((c_row % kTiledCMsPerWarp) / WarpGemm::kCMLane);
        constexpr uint32_t reg_offset_for_row_data =
            c_row / WarpGemm::kCMLane;
        // Lane index to source scale from
        uint32_t src_lane_idx = lane_base_offset + row_base +
                                (__lane_id() / WarpGemm::kN * kTileRows);
        // Directly index into thread buffer corresponding to
        // desired row coefficient
        auto& scale_reg =
            aq_block_tensor.get_thread_buffer()[src_reg_offset];
        uint32_t scale_reg_dword;
        if constexpr(std::is_same_v<AQDataType, float>)
        {
            scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
        }
        else
        {
            scale_reg_dword = static_cast<uint32_t>(scale_reg);
        }
        // Pull scale data across lanes
        int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
            src_lane_idx * 4, __builtin_bit_cast(int, scale_reg_dword));
        float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
        c_block_tensor
            .get_thread_buffer()[tbuf_offset + reg_offset_for_row_data] +=
            (c_warp_tensor.get_thread_buffer()[reg_offset_for_row_data] *
             scale_reg_f * kA_cvt_scale * kB_cvt_scale);
    });
}
});
});
});
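For reference, the cross-lane gather above relies on __builtin_amdgcn_ds_bpermute, which lets every lane in a wavefront read a 32-bit register from another lane; its first operand is a byte address, which is why the source lane index is scaled by 4 (the << 2 and * 4 above). Below is a minimal standalone HIP sketch of the pattern, with a hypothetical row-to-lane mapping that is not taken from this commit:

#include <hip/hip_runtime.h>

// Hypothetical demo (not from this commit): each lane owns the scale for
// one row; lanes accumulating rows (i/4)*4 .. (i/4)*4+3 gather the scale
// of their first row from the lane that owns it.
__global__ void gather_scales(const float* scales_in, float* out)
{
    const int lane = threadIdx.x & 63; // wave64 assumed
    const float my_scale = scales_in[lane];
    const int src_lane = (lane / 4) * 4;
    // ds_bpermute addresses lanes in bytes, hence the shift by 2.
    const int gathered = __builtin_amdgcn_ds_bpermute(
        src_lane << 2, __builtin_bit_cast(int, my_scale));
    out[lane] = __builtin_bit_cast(float, gathered);
}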


@@ -3,11 +3,14 @@
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
@@ -104,6 +107,7 @@ struct AQuantGemmKernel
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr bool Preshuffle = GemmPipeline::Preshuffle;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using AQDataType = remove_cvref_t<typename GemmPipeline::AQDataType>;
@@ -157,7 +161,7 @@ struct AQuantGemmKernel
__device__ SplitKBatchOffset(const AQuantGemmKernelArgs& kargs,
const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(I2);
const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
@@ -372,14 +376,75 @@ struct AQuantGemmKernel
}
}();
const auto get_padding_size = [](index_t length, index_t alignment) {
return ck_tile::integer_least_multiple(length, alignment) - length;
};
const auto& make_preshuffled_aq_tensor_view = [&]() {
const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
const auto aq_y = kargs.QK / GemmPipeline::KPerBlockAQ;
const auto aq_desc =
make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
make_tuple(aq_x, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
const auto aq_pad0_desc = transform_tensor_descriptor(
aq_desc,
make_tuple(make_pass_through_transform(aq_y),
make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
const auto wave_tile_size =
TilePartitioner::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
aq_pad0_desc,
make_tuple(make_pass_through_transform(aq_y),
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto aq_pad1_desc = transform_tensor_descriptor(
aq_unmerge_pad0_desc,
make_tuple(make_pass_through_transform(aq_y),
make_pass_through_transform(wave_tile_count_x),
make_right_pad_transform(
wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
const auto pad_wave_size =
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
aq_pad1_desc,
make_tuple(make_merge_transform(make_tuple(wave_tile_count_x, aq_y)),
make_pass_through_transform(pad_wave_size)),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
};
const auto& aq_tensor_view = [&]() {
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
if constexpr(Preshuffle)
{
return make_preshuffled_aq_tensor_view();
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
}();
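As a sanity check on the padding arithmetic used by make_preshuffled_aq_tensor_view, the following constexpr sketch mirrors get_padding_size and the wave-tile padding. The sizes (warp size 64, 16-row warp tiles, two scale groups per block row) are assumed for illustration, not taken from a particular configuration:

// Standalone sketch; least_multiple mirrors ck_tile::integer_least_multiple.
constexpr int least_multiple(int x, int align) { return (x + align - 1) / align * align; }
constexpr int padding_size(int length, int align) { return least_multiple(length, align) - length; }

constexpr int warp_size      = 64;
constexpr int warp_tile_m    = 16; // WarpTile::at(I0)
constexpr int k_per_block_aq = 2;  // KPerBlock / QuantGroupSize

// Each AQ wave tile holds warp_tile_m rows of k_per_block_aq scales.
constexpr int wave_tile_size = warp_tile_m * k_per_block_aq;             // 32
// Rows of the preshuffled view are padded to the warp size so that each
// of the 64 lanes can load exactly one element.
constexpr int pad_wave_size = least_multiple(wave_tile_size, warp_size); // 64
static_assert(padding_size(wave_tile_size, warp_size) == 32, "pad 32 -> 64");
static_assert(pad_wave_size == 64, "one element per lane");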
const auto& b_tensor_view = [&]() {
@@ -491,16 +556,7 @@ struct AQuantGemmKernel
}
}();
const auto& aq_pad_view = [&]() {
const auto& aq_tensor_view = views.at(I1);
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
return pad_tensor_view(
aq_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
// TODO: Add support for padding.
sequence<false, false>{});
}();
const auto& aq_pad_view = [&]() { return views.at(I1); }();
const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I2);
@@ -543,8 +599,10 @@ struct AQuantGemmKernel
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
const AQuantGemmKernelArgs& kargs,
const index_t i_m,
const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& aq_pad_view = views.at(I1);
@@ -570,11 +628,26 @@ struct AQuantGemmKernel
const auto& aq_block_window = [&]() {
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
return make_tile_window(
aq_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
{i_m, 0});
if constexpr(Preshuffle)
{
constexpr auto tile_window_width = get_warp_size();
constexpr auto tile_window_height =
TilePartitioner::MPerBlock / TilePartitioner::BlockGemmShape::WarpTile::at(I0);
auto block_m_idx = i_m / TilePartitioner::MPerBlock;
return make_tile_window(
aq_pad_view,
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_m_idx * kargs.K / TilePartitioner::BlockGemmShape::BlockTile::at(I2),
0});
}
else
{
return make_tile_window(
aq_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
{i_m, 0});
}
}();
const auto& b_block_window = [&]() {
@@ -633,7 +706,8 @@ struct AQuantGemmKernel
a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
auto gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, kargs, block_idx_m, block_idx_n);
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
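The origin arithmetic in the preshuffle branch of MakeGemmTileWindows can be checked in isolation; the sizes below are assumptions for illustration only, not values from the kernel:

// Assumed example sizes.
constexpr int m_per_block  = 128; // TilePartitioner::MPerBlock
constexpr int block_tile_k = 128; // BlockGemmShape::BlockTile::at(I2)
constexpr int K            = 512;

constexpr int i_m         = 384;               // this workgroup's offset along M
constexpr int block_m_idx = i_m / m_per_block; // 3
// Each block row of C owns K / block_tile_k consecutive rows of the
// preshuffled AQ view, so the window for this workgroup starts at:
constexpr int origin_row = block_m_idx * K / block_tile_k; // 12
static_assert(origin_row == 12, "window origin arithmetic");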


@@ -38,12 +38,9 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using YPerTile = number<MPerBlock>;
using XPerTile = number<KPerBlockAQ>;
auto aq_copy_dram_window =
make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(YPerTile(), XPerTile()),
aq_dram_block_window_tmp.get_window_lengths(),
aq_dram_block_window_tmp.get_window_origin(),
Policy::template MakeAQDramTileDistribution<Problem>());
return aq_copy_dram_window;


@@ -42,6 +42,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
constexpr bool Preshuffle = Problem::Traits::Preshuffle;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
@@ -52,14 +53,34 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
false>;
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
VecLoadSize>;
if constexpr(Preshuffle)
{
using TileEncodingPattern =
TileDistributionEncodingPatternAQ<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock / WarpGemm::kM,
ck_tile::integer_least_multiple(
WarpGemm::kM * KPerBlockAQ, get_warp_size()),
KPerBlockAQ,
VecLoadSize,
Preshuffle>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
else
{
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
KPerBlockAQ,
VecLoadSize,
Preshuffle>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
}
template <typename Problem>


@@ -7,7 +7,6 @@
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
@@ -134,6 +133,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
@@ -254,9 +254,6 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)");
static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
"Aq block window has incorrect lengths for defined AqLayout!");
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
@@ -312,8 +309,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// only row_major for AQ
constexpr AQDramTileWindowStep aq_dram_tile_window_step =
is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ);
Preshuffle ? make_array(MPerBlock / BlockGemm::WarpGemm::kM, 0)
: make_array(0, KPerBlockAQ);
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
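With the preshuffled layout the AQ window therefore steps through rows of the view (one block tile's worth of warp-tile rows per K iteration) instead of stepping along K; a small illustration with assumed sizes:

// Assumed example sizes, not values from the pipeline.
constexpr int m_per_block = 128;
constexpr int warp_gemm_m = 16; // BlockGemm::WarpGemm::kM
// One K iteration consumes MPerBlock / WarpGemm::kM rows of the view,
// matching the preshuffled tile window height defined in the kernel.
constexpr int preshuffle_step_rows = m_per_block / warp_gemm_m; // 8
static_assert(preshuffle_step_rows == 8, "window step in view rows");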


@@ -50,10 +50,11 @@ template <typename BlockGemmShape,
index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize>
index_t KPerBlockAQ,
index_t VecSize,
bool Preshuffle>
struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
@@ -69,26 +70,46 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter
// KWarps > 1 isn't supported
static_assert(KWarps == 1);
// # of elements per thread
static constexpr index_t X = XPerTile;
static constexpr index_t Y0 = 1;
static constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
static constexpr index_t Y2 = MWarps;
static constexpr index_t Y3 = WarpGemm::kM;
static_assert(Y3 >= WarpGemm::kM, "Scales for all rows must be available within the warp.");
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
"Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>,
tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X>>,
tuple<sequence<1, 0>, sequence<1, 1>>,
tuple<sequence<2, 0>, sequence<0, 3>>,
sequence<1, 2>,
sequence<1, 0>>{});
if constexpr(Preshuffle)
{
// # of elements per thread
constexpr index_t X2 = KPerBlockAQ;
constexpr index_t X1 = warp_size / X2;
constexpr index_t X0 = XPerTile / warp_size;
constexpr index_t Y1 = MWarps;
constexpr index_t Y0 = YPerTile / Y1;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>,
tuple<sequence<Y0, Y1>, sequence<X0, X1, X2>>,
tuple<sequence<1, 0>, sequence<2, 2>>,
tuple<sequence<1, 0>, sequence<1, 2>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
else
{
// # of elements per thread
constexpr index_t X = XPerTile;
constexpr index_t Y0 = 1;
constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
constexpr index_t Y2 = MWarps;
constexpr index_t Y3 = WarpGemm::kM;
static_assert(Y3 >= WarpGemm::kM,
"Scales for all rows must be available within the warp.");
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
"Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>,
tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X>>,
tuple<sequence<1, 0>, sequence<1, 1>>,
tuple<sequence<2, 0>, sequence<0, 3>>,
sequence<1, 2>,
sequence<1, 0>>{});
}
}
};
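To make the preshuffle factorization concrete, here is a worked constexpr example; the sizes are assumed for illustration and are not mandated by the pattern:

// Assumed example sizes.
constexpr int warp_size   = 64;
constexpr int KPerBlockAQ = 2;  // scales per block-tile row
constexpr int XPerTile    = 64; // integer_least_multiple(WarpGemm::kM * KPerBlockAQ, warp_size)
constexpr int YPerTile    = 8;  // MPerBlock / WarpGemm::kM
constexpr int MWarps      = 4;

constexpr int X2 = KPerBlockAQ;          // contiguous scales along K
constexpr int X1 = warp_size / X2;       // 32 lane groups span one padded row
constexpr int X0 = XPerTile / warp_size; // one warp-sized read per row
constexpr int Y1 = MWarps;
constexpr int Y0 = YPerTile / Y1;        // warp-tile rows handled per warp
static_assert(X0 * X1 * X2 == XPerTile, "X0..X2 must cover XPerTile");
static_assert(Y0 * Y1 == YPerTile, "Y0, Y1 must cover YPerTile");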


@@ -10,6 +10,7 @@ namespace ck_tile {
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
bool Preshuffle_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
@@ -29,6 +30,7 @@ struct TileGemmAQuantTraits
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = Preshuffle_;
};
} // namespace ck_tile