[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)

* initial poc

* factor out common parts in operator()

* cv4

* rest of the universal gemm pipelines

* fix test

* remove boilerplate from tile engine

* fix example

* fix example

* format

* fix tests build for gemm

* remove base pipeline codegen from gemm instance builder

* unify v3 logic with the rest of universal gemm pipelines

* fix build for multi abd test

* fix test gemm multi d

* fix build for weight preshuffle

* fix grouped gemm test

* fix grouped gemm multi d test

* fix grouped gemm preshuffle

* fix grouped gemm example except for quant

* fix gemm preshuffle

* fix splitk 2 stage example

* fix batched gemm example

* fix multid example

* fix multiabd example

* fix batched gemm test

* fixup

* fix examples build

* fix grouped gemm test build

* fix smoke builder

* hacky poc

* fix tile engine

* kill the lambda

* maybe fix test build

* more fixes

* clang-format

* save temp

* clang-format

* mostly fix examples

* clang-format

* remove dead code

* more cleanup

* fix fmha bwd build (default epilogue set/add appears to be broken)

* fix default epilogue tests but not correctness

* clang-format

* fix bquant

* clang-format

* cleanup dead code

* rearrange make windows for readability

* restore changes to IsSupportedArgument

* fix smoke-builder

* clang-format

* fixup rename class

* build fixes

* clang-format

* fix builder

* fixup

* remove set from builder tests

* fix test

* clang-format

* re-refactor the kernels

* clang-format

* fix header license

* remove memory operation from conv bwd test

* clang-format

* clang-format example,include

* clang-format test

* build fixes

* clang-format

* solve compilation error

* fix the CI

* solve compilation error

* clang format

* solve merge conflict

* solve merge conflict

* solve the gfx11 error

* solve test error

* moar build fixes

* remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Author: Max Podkorytov
Date: 2026-01-04 03:28:14 -08:00
Committed by: GitHub
Parent: ec23be0b9d
Commit: e339101e9c
68 changed files with 4198 additions and 4298 deletions
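At its core, the refactor takes the destination memory operation out of the cshuffle epilogue type: each kernel now builds its C tile window with an explicit memory_operation_enum chosen from k_batch at run time. Below is a minimal, self-contained C++ sketch of that dispatch pattern; store_c and run_epilogue are hypothetical stand-ins rather than the CK-Tile API, and the plain += stands in for the device atomic add:

enum class memory_operation_enum { set, atomic_add };

// Toy sketch of the new dispatch: the write policy is a template parameter
// of the C window builder, not a property of the epilogue type.
template <memory_operation_enum DstInMemOp>
void store_c(float* c, float partial)
{
    if constexpr(DstInMemOp == memory_operation_enum::set)
        *c = partial;  // k_batch == 1: this workgroup owns the C tile
    else
        *c += partial; // split-K: stands in for an atomic add of partial sums
}

void run_epilogue(float* c, float partial, int k_batch)
{
    if(k_batch == 1)
        store_c<memory_operation_enum::set>(c, partial);
    else
        store_c<memory_operation_enum::atomic_add>(c, partial);
}

The same two-way split appears in every kernel below: MakeCBlockWindow<memory_operation_enum::set> when k_batch == 1, MakeCBlockWindow<memory_operation_enum::atomic_add> otherwise.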


@@ -401,6 +401,592 @@ struct QuantGemmKernel
index_t splitted_k;
};
CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
const QuantGemmKernelArgs& kargs,
const index_t k_size,
const index_t i_m)
{
// Step 1: Create tensor view for A
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, k_size),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(k_size, kargs.M),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}();
// Step 2: Create padded view
const auto& a_pad_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, GemmPipeline::kPadM>{});
}
}();
// Step 3: Create tile window
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();
return a_block_window;
}
CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr,
const QuantGemmKernelArgs& kargs,
const index_t i_m,
const index_t i_n)
{
// Step 1: Create tensor view for AQ
const auto& aq_tensor_view = [&]() {
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
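// Preshuffled AQ is stored as (QK_A / KPerBlockAQ) rows of (M * KPerBlockAQ)
// scales. The transform chain below pads each row to a whole number of block
// tiles, splits it into per-wave tiles, pads each wave tile to a full warp,
// then re-merges so one row of the final view is one warp-aligned wave tile.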
const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
const auto aq_desc =
make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
make_tuple(aq_x, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
const auto aq_pad0_desc = transform_tensor_descriptor(
aq_desc,
make_tuple(
make_pass_through_transform(aq_y),
make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
const auto wave_tile_size =
GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
const auto wave_tile_count_x =
ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
aq_pad0_desc,
make_tuple(
make_pass_through_transform(aq_y),
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto aq_pad1_desc = transform_tensor_descriptor(
aq_unmerge_pad0_desc,
make_tuple(
make_pass_through_transform(aq_y),
make_pass_through_transform(wave_tile_count_x),
make_right_pad_transform(
wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
const auto pad_wave_size =
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
aq_pad1_desc,
make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
make_pass_through_transform(pad_wave_size)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
}
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped) &&
!PreshuffleQuant)
{
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK_A),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
else // Column major AQ
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.QK_A, kargs.M),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, 0), // broadcasting over n
number<1>{},
number<1>{});
}
else
{
return nullptr;
}
}();
// Step 2: Create tile window (no padding for AQ)
const auto& aq_block_window = [&]() {
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
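// Window geometry for the preshuffled view: one row holds the scales for
// warp_m output rows (warp_m * aqk_per_block values, padded to a full warp),
// so a block of block_m output rows spans block_m / warp_m consecutive rows.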
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
constexpr auto tile_window_height = block_m / warp_m;
auto block_m_idx = i_m / block_m;
return make_tile_window(
aq_tensor_view,
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_m_idx * tile_window_height, 0});
}
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto block_m = TilePartitioner::MPerBlock;
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(aq_tensor_view,
make_tuple(number<block_m>{}, number<aqk_per_block>{}),
{i_m, 0});
}
else // Column major AQ
{
return make_tile_window(aq_tensor_view,
make_tuple(number<aqk_per_block>{}, number<block_m>{}),
{0, i_m});
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto block_k = TilePartitioner::KPerBlock;
return make_tile_window(
aq_tensor_view,
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
{i_m, 0});
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_tile_window(aq_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else
{
return nullptr;
}
}();
return aq_block_window;
}
CK_TILE_DEVICE static auto MakeBBlockWindow(const BDataType* b_ptr,
const QuantGemmKernelArgs& kargs,
const index_t k_size,
const index_t i_n)
{
// Step 1: Create tensor view for B
const auto& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = k_size / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
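// PermuteB stores B as packed (K0, N, K1) blocks with K1 contiguous, so a
// global read can vectorize over at most K1 elements; clamp to K1.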
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(k_size, kargs.N),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
else
{
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = k_size / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
}
else
{
if constexpr(PreshuffleB)
{
index_t kFlatK =
GemmPipeline::flatKPerWarp *
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
else
{
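// pk_fp4 packs two 4-bit values per element, so the K extent of the
// view is k_size / 2.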
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, k_size / 2),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
else
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, k_size),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
}
}();
// Step 2: Create padded view (or flat view for PreshuffleB)
const auto& b_pad_view = [&]() {
if constexpr(PreshuffleB)
{
return b_tensor_view; // no padding for preshuffle
}
else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / 2>{}),
sequence<false, GemmPipeline::kPadK>{});
else
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
}();
// Step 3: Create tile window
const auto& b_block_window = [&]() {
if constexpr(PreshuffleB)
{
return make_tile_window(
b_pad_view,
make_tuple(number<GemmPipeline::flatNPerWarp>{},
number<GemmPipeline::flatKPerWarp>{}),
{static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0});
}
else
{
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
return make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / 2>{}),
{i_n, 0});
else
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
}
else
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
}
}
}();
return b_block_window;
}
CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr,
const QuantGemmKernelArgs& kargs,
const index_t i_m,
const index_t i_n)
{
// Step 1: Create tensor view for BQ
const auto& bq_tensor_view = [&]() {
if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(0, 1), // broadcasting over m
number<1>{},
number<1>{});
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
if constexpr(PreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
"PreshuffleQuant with BQuantGrouped currently only supports "
"ColumnMajor BQ layout");
return MakePreshuffledQuantTensorView<
GemmPipeline::KPerBlockBQ,
GemmPipeline::NPerBlock,
TilePartitioner::BlockGemmShape::WarpTile::at(I1),
GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B);
}
else
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK),
integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN),
integer_divide_ceil(kargs.K, QuantGroupSize::kK)),
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
make_tuple(kargs.stride_BQ, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
else
{
return nullptr;
}
}();
// Step 2: Create tile window (no padding for BQ)
const auto& bq_block_window = [&]() {
if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_tile_window(bq_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
if constexpr(PreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN;
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
constexpr auto tile_window_height = block_n / warp_n;
auto block_n_idx = i_n / block_n;
return make_tile_window(
bq_tensor_view,
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_n_idx * tile_window_height, 0});
}
else
{
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(
bq_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
{0, i_n / QuantGroupSize::kN});
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_tile_window(
bq_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
}
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
return make_tile_window(
bq_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
}
else
{
return nullptr;
}
}();
return bq_block_window;
}
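// DstInMemOp selects how the epilogue writes C: memory_operation_enum::set
// issues plain stores when k_batch == 1, while atomic_add lets split-K
// workgroups accumulate partial results into the same C tile.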
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto MakeCBlockWindow(CDataType* c_ptr,
const QuantGemmKernelArgs& kargs,
const index_t i_m,
const index_t i_n)
{
// Step 1: Create tensor view for C
const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
// Step 2: Create padded view
const auto& c_pad_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
// Step 3: Create tile window
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return c_block_window;
}
CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
{
if(kargs.k_batch != 1)
@@ -1143,9 +1729,7 @@ struct QuantGemmKernel
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
const AQDataType* aq_ptr,
@@ -1157,25 +1741,22 @@ struct QuantGemmKernel
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create block windows using specialized methods
const auto& a_block_window =
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_block_window =
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
const index_t num_loop =
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = [&]() {
if constexpr(kQuantType == QuantType::AQuantGrouped)
{
index_t m = 0;
if constexpr(PreshuffleQuant)
{
m = kargs.M;
@@ -1185,8 +1766,7 @@ struct QuantGemmKernel
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
index_t n = 0;
if constexpr(PreshuffleQuant)
{
n = kargs.N;
@@ -1196,10 +1776,8 @@ struct QuantGemmKernel
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
index_t m = 0;
index_t n = 0;
if constexpr(PreshuffleQuant)
{
m = kargs.M;
@@ -1222,86 +1800,111 @@ struct QuantGemmKernel
}
}();
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
// Run Epilogue Pipeline with k_batch dispatch
if(k_batch == 1)
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, kargs, block_idx_m, block_idx_n);
if constexpr(kQuantType == QuantType::ABQuantGrouped ||
kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::BQuantGrouped)
{
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,
smem_ptr_0,
aq_block_window,
bq_block_window);
}
else if constexpr(kQuantType == QuantType::TensorQuant)
{
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
}
}
else
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
if constexpr(kQuantType == QuantType::ABQuantGrouped ||
kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::BQuantGrouped)
{
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,
smem_ptr_0,
aq_block_window,
bq_block_window);
}
else if constexpr(kQuantType == QuantType::TensorQuant)
{
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
}
}
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @note RunGemm2LDS runs with two shared memory buffers using the ping-pong buffer mechanism.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param aq_ptr input AQ pointer
* @param bq_ptr input BQ pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
const BDataType* b_ptr,
[[maybe_unused]] const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const QuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create block windows using specialized methods
const auto& a_block_window =
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_block_window =
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
const index_t num_loop =
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = [&]() {
if constexpr(kQuantType == QuantType::BQuantGrouped)
{
index_t n = 0;
if constexpr(PreshuffleQuant)
{
n = kargs.N;
@@ -1320,19 +1923,23 @@ struct QuantGemmKernel
}
}();
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
// Run Epilogue Pipeline with k_batch dispatch
if constexpr(kQuantType == QuantType::BQuantGrouped)
{
if(k_batch == 1)
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
else
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
}
else
{
return;
}
}
@@ -1343,16 +1950,19 @@ struct QuantGemmKernel
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
const SplitKBatchOffset splitk_batch_offset(kargs);
// Apply splitk offset to input pointers
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];


@@ -374,7 +374,7 @@ struct QuantGroupedGemmKernel
CK_TILE_DEVICE static void
RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
const BDataType* b_ptr,
[[maybe_unused]] const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
@@ -385,25 +385,21 @@ struct QuantGroupedGemmKernel
const index_t block_idx_n)
{
static_assert(kQuantType == QuantType::BQuantGrouped, "kQuantType must be BQuantGrouped");
// Create block windows using specialized methods
const auto& a_block_window =
Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_block_window =
Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& bq_block_window =
Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM cooperatively by whole workgroup
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
b_block_window,
bq_block_window,
num_loop,
@@ -411,10 +407,20 @@ struct QuantGroupedGemmKernel
smem_ptr_0,
smem_ptr_1);
// Run Epilogue Pipeline with split_k dispatch
if(kargs.k_batch == 1)
{
auto c_block_window = Base::template MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
else
{
auto c_block_window =
Base::template MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
}
/**
@@ -449,16 +455,15 @@ struct QuantGroupedGemmKernel
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create block windows using specialized methods
const auto& a_block_window =
Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_block_window =
Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& aq_block_window =
Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
const auto& bq_block_window =
Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
// Get hot-loop and tail configuration
const index_t num_loop = __builtin_amdgcn_readfirstlane(
@@ -466,51 +471,77 @@ struct QuantGroupedGemmKernel
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM cooperatively by whole workgroup
const auto& c_block_tile = [&]() {
if constexpr(kQuantType == QuantType::AQuantGrouped)
{
return GemmPipeline{}.template operator()(a_block_window,
b_block_window,
aq_block_window,
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0);
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
return GemmPipeline{}.template operator()(a_block_window,
b_block_window,
bq_block_window,
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0);
}
else if constexpr(kQuantType == QuantType::RowColQuant ||
kQuantType == QuantType::TensorQuant)
{
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
}
}();
// Run Epilogue Pipeline with split_k dispatch
if(kargs.k_batch == 1)
{
auto c_block_window = Base::template MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, kargs, block_idx_m, block_idx_n);
if constexpr(kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::BQuantGrouped)
{
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,
smem_ptr_0,
aq_block_window,
bq_block_window);
}
else if constexpr(kQuantType == QuantType::TensorQuant)
{
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
}
}
else
{
auto c_block_window =
Base::template MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
if constexpr(kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::BQuantGrouped)
{
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,