mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)
* initial poc * factor out common parts in operator() * cv4 * rest of the universal gemm pipelines * fix test * remove boilerplate from tile engine * fix example * fix example * format * fix tests build for gemm * remove base pipeline codegen from gemm instance builder * unify v3 logic with the rest of universal gemm pipelines * fix build for multi abd test * fix test gemm multi d * fix build for weight preshuffle * fix grouped gemm test * fix grouped gemm multi d test * fix grouped gemm preshuffle * fix grouped gemm example except for quant * fix gemm preshuffle * fix splitk 2 stage example * fix batched gemm example * fix multid example * fix multiabd example * fix batched gemm test * fixup * fix examples build * fix grouped gemm test build * fix smoke builder * hacky poc * fix tile engine * kill the lambda * maybe fix test build * more fixes * clang-format * save temp * clang-format * mostly fix examples * clang-format * remove dead code * more cleanup * fix fmha bwd build (default epilogue set/add appears to be broken) * fix default epilogue tests but not correctness * clang-format * fix bquant * clang-format * cleanup dead code * rearrange make windows for readability * restore changes to IsSupportedArgument * fix smoke-builder * clang-format * fixup rename class * build fixes * clang-format * fix builder * fixup * remove set from builder tests * fix test * clang-format * re-refactor the kernels * clang-format * fix header license * remove memory operation from conv bwd test * clang-format * clang-format example,include * clang-format test * build fixes * clang-format * solve compilation error * fix the CI * solve compilation error * clang format * solve merge conflict * solve merge conflict * solve the gfx11 error * solve test error * moar build fixes * remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -401,6 +401,592 @@ struct QuantGemmKernel
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
/**
 * @brief Creates the block-level tile window over the A operand.
 *
 * The global-address-space tensor view, its padded view and the tile window are built in
 * one pass for the layout selected by ALayout:
 *  - RowMajor: view lengths (M, k_size), window MPerBlock x KPerBlock at origin {i_m, 0},
 *    pad flags <false, GemmPipeline::kPadK>.
 *  - otherwise: view lengths (k_size, M), window KPerBlock x MPerBlock at origin {0, i_m},
 *    pad flags <false, GemmPipeline::kPadM>.
 * In both layouts the strides are (kargs.stride_A, 1).
 *
 * @param a_ptr  Pointer used as the base of the A tensor view.
 * @param kargs  Kernel arguments; M and stride_A are read.
 * @param k_size Length of the K dimension covered by the view.
 * @param i_m    Window origin along the M dimension.
 * @return Tile window over the padded A view.
 */
CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
                                            const QuantGemmKernelArgs& kargs,
                                            const index_t k_size,
                                            const index_t i_m)
{
    if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
    {
        // A is addressed as (M, K).
        const auto a_view = make_naive_tensor_view<address_space_enum::global>(
            a_ptr,
            make_tuple(kargs.M, k_size),
            make_tuple(kargs.stride_A, 1),
            number<GemmPipeline::GetVectorSizeA()>{},
            number<1>{});

        const auto a_padded = pad_tensor_view(
            a_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            sequence<false, GemmPipeline::kPadK>{});

        return make_tile_window(
            a_padded,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            {i_m, 0});
    }
    else
    {
        // A is addressed as (K, M).
        const auto a_view = make_naive_tensor_view<address_space_enum::global>(
            a_ptr,
            make_tuple(k_size, kargs.M),
            make_tuple(kargs.stride_A, 1),
            number<GemmPipeline::GetVectorSizeA()>{},
            number<1>{});

        const auto a_padded = pad_tensor_view(
            a_view,
            make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
            sequence<false, GemmPipeline::kPadM>{});

        return make_tile_window(
            a_padded,
            make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
            {0, i_m});
    }
}
|
||||
|
||||
/**
 * @brief Creates the block-level tile window over the A quantization tensor (AQ).
 *
 * No pad view is created for AQ; the tile window is placed directly on the tensor view.
 * The construction depends on kQuantType / PreshuffleQuant:
 *  - AQuantGrouped with PreshuffleQuant: a transformed descriptor (right-pad, unmerge,
 *    right-pad, merge) is built and a window of
 *    (MPerBlock / warp_m) x integer_least_multiple(warp_m * aqk_per_block, warp size)
 *    is placed at row (i_m / MPerBlock) * window height. Requires RowMajor AQLayout.
 *  - AQuantGrouped without PreshuffleQuant: naive (M, QK_A) or (QK_A, M) view depending
 *    on AQLayout, window origin {i_m, 0} or {0, i_m}.
 *  - ABQuantGrouped without PreshuffleQuant: naive (M, QK_A) view, window sized with
 *    GemmPipeline::AQuantGroupSize. Requires RowMajor AQLayout.
 *  - RowColQuant: (M, N) view with strides (1, 0), window MPerBlock x NPerBlock at
 *    {i_m, i_n}.
 *  - any other combination: nullptr is returned.
 *
 * @param aq_ptr Pointer used as the base of the AQ tensor view.
 * @param kargs  Kernel arguments; M, N, QK_A and stride_AQ are read depending on the mode.
 * @param i_m    Window origin along the M dimension.
 * @param i_n    Window origin along the N dimension (only used for RowColQuant).
 * @return Tile window over AQ, or nullptr when the mode is not handled here.
 */
CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr,
                                             const QuantGemmKernelArgs& kargs,
                                             const index_t i_m,
                                             const index_t i_n)
{
    constexpr bool aq_is_row_major = std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>;

    if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
    {
        static_assert(aq_is_row_major);

        // --- tensor view: (rows, cols) descriptor reshaped through pad/unmerge/pad/merge ---
        const auto aq_cols = kargs.M * GemmPipeline::KPerBlockAQ;
        const auto aq_rows = kargs.QK_A / GemmPipeline::KPerBlockAQ;

        const auto base_desc =
            make_naive_tensor_descriptor(make_tuple(aq_rows, aq_cols),
                                         make_tuple(aq_cols, 1),
                                         number<GemmPipeline::GetVectorSizeAQ()>{},
                                         number<1>{});

        // Right-pad the column dimension using the per-block tile size.
        const auto per_block_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
        const auto block_padded_desc = transform_tensor_descriptor(
            base_desc,
            make_tuple(
                make_pass_through_transform(aq_rows),
                make_right_pad_transform(aq_cols, get_padding_size(aq_cols, per_block_size))),
            make_tuple(sequence<0>{}, sequence<1>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        // Split the padded columns into (wave tile count, wave tile size).
        const auto padded_cols = block_padded_desc.get_lengths()[I1];
        const auto per_wave_size =
            GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
        const auto wave_tiles = ck_tile::integer_divide_ceil(padded_cols, per_wave_size);

        const auto wave_split_desc = transform_tensor_descriptor(
            block_padded_desc,
            make_tuple(make_pass_through_transform(aq_rows),
                       make_unmerge_transform(make_tuple(wave_tiles, per_wave_size))),
            make_tuple(sequence<0>{}, sequence<1>{}),
            make_tuple(sequence<0>{}, sequence<1, 2>{}));

        // Right-pad each wave tile using the warp size.
        const auto wave_padded_desc = transform_tensor_descriptor(
            wave_split_desc,
            make_tuple(make_pass_through_transform(aq_rows),
                       make_pass_through_transform(wave_tiles),
                       make_right_pad_transform(
                           per_wave_size, get_padding_size(per_wave_size, get_warp_size()))),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

        // Merge (rows, wave tile count) back into a single leading dimension.
        const auto padded_wave_size =
            ck_tile::integer_least_multiple(per_wave_size, get_warp_size());
        const auto merged_desc = transform_tensor_descriptor(
            wave_padded_desc,
            make_tuple(make_merge_transform(make_tuple(aq_rows, wave_tiles)),
                       make_pass_through_transform(padded_wave_size)),
            make_tuple(sequence<0, 1>{}, sequence<2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        const auto aq_view = make_tensor_view<address_space_enum::global>(aq_ptr, merged_desc);

        // --- tile window ---
        using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
        constexpr auto m_per_block   = TilePartitioner::MPerBlock;
        constexpr auto m_per_warp    = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
        constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
        constexpr auto window_cols =
            ck_tile::integer_least_multiple(m_per_warp * aqk_per_block, get_warp_size());
        constexpr auto window_rows = m_per_block / m_per_warp;
        auto m_block_index         = i_m / m_per_block;

        return make_tile_window(aq_view,
                                make_tuple(number<window_rows>{}, number<window_cols>{}),
                                {m_block_index * window_rows, 0});
    }
    else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
    {
        using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
        constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
        constexpr auto m_per_block   = TilePartitioner::MPerBlock;

        if constexpr(aq_is_row_major)
        {
            const auto aq_view = make_naive_tensor_view<address_space_enum::global>(
                aq_ptr,
                make_tuple(kargs.M, kargs.QK_A),
                make_tuple(kargs.stride_AQ, 1),
                number<GemmPipeline::GetVectorSizeAQ()>{},
                number<1>{});

            return make_tile_window(
                aq_view, make_tuple(number<m_per_block>{}, number<aqk_per_block>{}), {i_m, 0});
        }
        else // Column major AQ
        {
            const auto aq_view = make_naive_tensor_view<address_space_enum::global>(
                aq_ptr,
                make_tuple(kargs.QK_A, kargs.M),
                make_tuple(kargs.stride_AQ, 1),
                number<GemmPipeline::GetVectorSizeAQ()>{},
                number<1>{});

            return make_tile_window(
                aq_view, make_tuple(number<aqk_per_block>{}, number<m_per_block>{}), {0, i_m});
        }
    }
    else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant)
    {
        static_assert(aq_is_row_major);
        using QuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
        constexpr auto m_per_block = TilePartitioner::MPerBlock;
        constexpr auto k_per_block = TilePartitioner::KPerBlock;

        const auto aq_view = make_naive_tensor_view<address_space_enum::global>(
            aq_ptr,
            make_tuple(kargs.M, kargs.QK_A),
            make_tuple(kargs.stride_AQ, 1),
            number<GemmPipeline::GetVectorSizeAQ()>{},
            number<1>{});

        return make_tile_window(
            aq_view,
            make_tuple(number<m_per_block>{}, number<k_per_block / QuantGroupSize::kK>{}),
            {i_m, 0});
    }
    else if constexpr(kQuantType == QuantType::RowColQuant)
    {
        const auto aq_view = make_naive_tensor_view<address_space_enum::global>(
            aq_ptr,
            make_tuple(kargs.M, kargs.N),
            make_tuple(1, 0), // broadcasting over n
            number<1>{},
            number<1>{});

        return make_tile_window(
            aq_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});
    }
    else
    {
        // No AQ window is produced for the remaining quantization modes.
        return nullptr;
    }
}
|
||||
|
||||
/**
 * @brief Creates the block-level tile window over the B operand.
 *
 * Stage 1 selects the global tensor view:
 *  - PermuteB: a (K0, N, K1) descriptor with K1 = GemmPipeline::GetSmemPackB() and
 *    K0 = k_size / K1 is merged to two dimensions; the output dimension order depends
 *    on BLayout.
 *  - RowMajor without PermuteB: naive (k_size, N) view.
 *  - non-RowMajor with PreshuffleB: naive (kFlatN, kFlatK) view.
 *  - otherwise: naive (N, k_size) view, or (N, k_size / 2) when BDataType is pk_fp4_raw_t.
 *
 * Stage 2 pads the view (skipped for PreshuffleB) and places the tile window:
 *  - PreshuffleB: flatNPerWarp x flatKPerWarp window at row i_n / WarpTile::at(I1).
 *  - ColumnMajor: NPerBlock x KPerBlock (KPerBlock / 2 for pk_fp4_raw_t) at {i_n, 0},
 *    pad flags <false, GemmPipeline::kPadK>.
 *  - otherwise: KPerBlock x NPerBlock at {0, i_n}, pad flags <false, GemmPipeline::kPadN>.
 *
 * @param b_ptr  Pointer used as the base of the B tensor view.
 * @param kargs  Kernel arguments; N, K and stride_B are read depending on the configuration.
 * @param k_size Length of the K dimension covered by the view.
 * @param i_n    Window origin along the N dimension.
 * @return Tile window over the (optionally padded) B view.
 */
CK_TILE_DEVICE static auto MakeBBlockWindow(const BDataType* b_ptr,
                                            const QuantGemmKernelArgs& kargs,
                                            const index_t k_size,
                                            const index_t i_n)
{
    constexpr bool b_is_row_major  = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
    constexpr bool b_is_col_major  = std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>;
    constexpr bool b_is_permuted   = GemmPipeline::BlockGemmShape::PermuteB;
    constexpr bool b_is_packed_fp4 = std::is_same_v<BDataType, pk_fp4_raw_t>;

    // Stage 1: tensor view for B.
    const auto b_view = [&]() {
        if constexpr(b_is_permuted)
        {
            // Both layouts share the same (K0, N, K1) source descriptor.
            constexpr index_t K1          = GemmPipeline::GetSmemPackB();
            const index_t K0              = k_size / K1;
            constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());

            const auto k0_n_k1_desc =
                make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
                                             make_tuple(kargs.N * K1, K1, I1),
                                             number<VectorSizeB>{},
                                             number<1>{});

            if constexpr(b_is_row_major)
            {
                const auto merged_desc = transform_tensor_descriptor(
                    k0_n_k1_desc,
                    make_tuple(make_merge_transform(make_tuple(K0, K1)),
                               make_pass_through_transform(kargs.N)),
                    make_tuple(sequence<0, 2>{}, sequence<1>{}),
                    make_tuple(sequence<0>{}, sequence<1>{}));
                return make_tensor_view<address_space_enum::global>(b_ptr, merged_desc);
            }
            else
            {
                // Same merge, but the output dimensions are swapped.
                const auto merged_desc = transform_tensor_descriptor(
                    k0_n_k1_desc,
                    make_tuple(make_merge_transform(make_tuple(K0, K1)),
                               make_pass_through_transform(kargs.N)),
                    make_tuple(sequence<0, 2>{}, sequence<1>{}),
                    make_tuple(sequence<1>{}, sequence<0>{}));
                return make_tensor_view<address_space_enum::global>(b_ptr, merged_desc);
            }
        }
        else if constexpr(b_is_row_major)
        {
            return make_naive_tensor_view<address_space_enum::global>(
                b_ptr,
                make_tuple(k_size, kargs.N),
                make_tuple(kargs.stride_B, 1),
                number<GemmPipeline::GetVectorSizeB()>{},
                number<1>{});
        }
        else if constexpr(PreshuffleB)
        {
            index_t kFlatK = GemmPipeline::flatKPerWarp *
                             (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
            index_t kFlatN = kargs.N * kargs.K / kFlatK;

            return make_naive_tensor_view<address_space_enum::global>(
                b_ptr,
                make_tuple(kFlatN, kFlatK),
                make_tuple(kFlatK, 1),
                number<GemmPipeline::GetVectorSizeB()>{},
                number<1>{});
        }
        else if constexpr(b_is_packed_fp4)
        {
            return make_naive_tensor_view<address_space_enum::global>(
                b_ptr,
                make_tuple(kargs.N, k_size / 2),
                make_tuple(kargs.stride_B, 1),
                number<GemmPipeline::GetVectorSizeB()>{},
                number<1>{});
        }
        else
        {
            return make_naive_tensor_view<address_space_enum::global>(
                b_ptr,
                make_tuple(kargs.N, k_size),
                make_tuple(kargs.stride_B, 1),
                number<GemmPipeline::GetVectorSizeB()>{},
                number<1>{});
        }
    }();

    // Stage 2: pad the view (not for PreshuffleB) and place the tile window.
    if constexpr(PreshuffleB)
    {
        // no padding for preshuffle
        return make_tile_window(
            b_view,
            make_tuple(number<GemmPipeline::flatNPerWarp>{}, number<GemmPipeline::flatKPerWarp>{}),
            {static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0});
    }
    else if constexpr(b_is_col_major)
    {
        if constexpr(b_is_packed_fp4)
        {
            const auto b_padded =
                pad_tensor_view(b_view,
                                make_tuple(number<TilePartitioner::NPerBlock>{},
                                           number<TilePartitioner::KPerBlock / 2>{}),
                                sequence<false, GemmPipeline::kPadK>{});

            return make_tile_window(b_padded,
                                    make_tuple(number<TilePartitioner::NPerBlock>{},
                                               number<TilePartitioner::KPerBlock / 2>{}),
                                    {i_n, 0});
        }
        else
        {
            const auto b_padded = pad_tensor_view(b_view,
                                                  make_tuple(number<TilePartitioner::NPerBlock>{},
                                                             number<TilePartitioner::KPerBlock>{}),
                                                  sequence<false, GemmPipeline::kPadK>{});

            return make_tile_window(b_padded,
                                    make_tuple(number<TilePartitioner::NPerBlock>{},
                                               number<TilePartitioner::KPerBlock>{}),
                                    {i_n, 0});
        }
    }
    else
    {
        const auto b_padded = pad_tensor_view(
            b_view,
            make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            sequence<false, GemmPipeline::kPadN>{});

        return make_tile_window(
            b_padded,
            make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {0, i_n});
    }
}
|
||||
|
||||
/**
 * @brief Creates the block-level tile window over the B quantization tensor (BQ).
 *
 * No pad view is created for BQ; the tile window is placed directly on the tensor view.
 * The construction depends on kQuantType / PreshuffleQuant:
 *  - RowColQuant: (M, N) view with strides (0, 1), window MPerBlock x NPerBlock at
 *    {i_m, i_n}.
 *  - BQuantGrouped with PreshuffleQuant: view from MakePreshuffledQuantTensorView, window
 *    of (block_n / warp_n) x integer_least_multiple(warp_n * bqk_per_block, warp size)
 *    at row (i_n / block_n) * window height. Requires ColumnMajor BQLayout.
 *  - BQuantGrouped without PreshuffleQuant: naive view over
 *    (ceil(K / kK), ceil(N / kN)) for RowMajor or (ceil(N / kN), ceil(K / kK)) for
 *    ColumnMajor, window origin at column/row i_n / kN.
 *  - ABQuantGrouped: naive (ceil(N / kN), QK_B) view with strides (stride_BQ, 1), sized
 *    with GemmPipeline::BQuantGroupSize. Requires ColumnMajor BQLayout.
 *  - any other mode: nullptr is returned.
 *
 * @param bq_ptr Pointer used as the base of the BQ tensor view.
 * @param kargs  Kernel arguments; M, N, K, QK_B and stride_BQ are read depending on the mode.
 * @param i_m    Window origin along the M dimension (only used for RowColQuant).
 * @param i_n    Window origin along the N dimension.
 * @return Tile window over BQ, or nullptr when the mode is not handled here.
 */
CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr,
                                             const QuantGemmKernelArgs& kargs,
                                             const index_t i_m,
                                             const index_t i_n)
{
    constexpr bool bq_is_row_major = std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>;
    constexpr bool bq_is_col_major = std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>;

    if constexpr(kQuantType == QuantType::RowColQuant)
    {
        const auto bq_view = make_naive_tensor_view<address_space_enum::global>(
            bq_ptr,
            make_tuple(kargs.M, kargs.N),
            make_tuple(0, 1), // broadcasting over m
            number<1>{},
            number<1>{});

        return make_tile_window(
            bq_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});
    }
    else if constexpr(kQuantType == QuantType::BQuantGrouped)
    {
        using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;

        if constexpr(PreshuffleQuant)
        {
            static_assert(bq_is_col_major,
                          "PreshuffleQuant with BQuantGrouped currently only supports "
                          "ColumnMajor BQ layout");

            const auto bq_view = MakePreshuffledQuantTensorView<
                GemmPipeline::KPerBlockBQ,
                GemmPipeline::NPerBlock,
                TilePartitioner::BlockGemmShape::WarpTile::at(I1),
                GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B);

            constexpr auto n_groups_per_block = TilePartitioner::NPerBlock / QuantGroupSize::kN;
            constexpr auto n_per_warp    = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
            constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
            constexpr auto window_cols =
                ck_tile::integer_least_multiple(n_per_warp * bqk_per_block, get_warp_size());
            constexpr auto window_rows = n_groups_per_block / n_per_warp;
            auto n_block_index         = i_n / n_groups_per_block;

            return make_tile_window(bq_view,
                                    make_tuple(number<window_rows>{}, number<window_cols>{}),
                                    {n_block_index * window_rows, 0});
        }
        else if constexpr(bq_is_row_major)
        {
            const auto bq_view = make_naive_tensor_view<address_space_enum::global>(
                bq_ptr,
                make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK),
                           integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
                make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1),
                number<GemmPipeline::GetVectorSizeBQ()>{},
                number<1>{});

            return make_tile_window(
                bq_view,
                make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
                           number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
                {0, i_n / QuantGroupSize::kN});
        }
        else
        {
            static_assert(bq_is_col_major);

            const auto bq_view = make_naive_tensor_view<address_space_enum::global>(
                bq_ptr,
                make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN),
                           integer_divide_ceil(kargs.K, QuantGroupSize::kK)),
                make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1),
                number<GemmPipeline::GetVectorSizeBQ()>{},
                number<1>{});

            return make_tile_window(
                bq_view,
                make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
                           number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
                {i_n / QuantGroupSize::kN, 0});
        }
    }
    else if constexpr(kQuantType == QuantType::ABQuantGrouped)
    {
        static_assert(bq_is_col_major);
        using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;

        const auto bq_view = make_naive_tensor_view<address_space_enum::global>(
            bq_ptr,
            make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
            make_tuple(kargs.stride_BQ, 1),
            number<GemmPipeline::GetVectorSizeBQ()>{},
            number<1>{});

        return make_tile_window(
            bq_view,
            make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
                       number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
            {i_n / QuantGroupSize::kN, 0});
    }
    else
    {
        // No BQ window is produced for the remaining quantization modes.
        return nullptr;
    }
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto MakeCBlockWindow(CDataType* c_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor view for C
|
||||
const auto& c_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_C),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& c_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 3: Create tile window
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return c_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
@@ -1143,9 +1729,7 @@ struct QuantGemmKernel
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
* @tparam DstInMemOp Destination memory operation (default: set).
|
||||
*/
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const AQDataType* aq_ptr,
|
||||
@@ -1157,25 +1741,22 @@ struct QuantGemmKernel
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_block_window =
|
||||
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = [&]() {
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(I1);
|
||||
index_t m = 0;
|
||||
index_t m = 0;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
m = kargs.M;
|
||||
@@ -1185,8 +1766,7 @@ struct QuantGemmKernel
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I3);
|
||||
index_t n = 0;
|
||||
index_t n = 0;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
n = kargs.N;
|
||||
@@ -1196,10 +1776,8 @@ struct QuantGemmKernel
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I3);
|
||||
index_t m = 0;
|
||||
index_t n = 0;
|
||||
index_t m = 0;
|
||||
index_t n = 0;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
m = kargs.M;
|
||||
@@ -1222,86 +1800,111 @@ struct QuantGemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I4);
|
||||
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
|
||||
|
||||
if constexpr(kQuantType == QuantType::ABQuantGrouped ||
|
||||
kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
// Run Epilogue Pipeline with k_batch dispatch
|
||||
if(k_batch == 1)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
if constexpr(kQuantType == QuantType::ABQuantGrouped ||
|
||||
kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
else
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
// TODO: why doesn't readfirstlane work here?
|
||||
// const AccDataType aq_scale =
|
||||
// __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*aq_ptr));
|
||||
// const AccDataType bq_scale =
|
||||
// __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*bq_ptr));
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
if constexpr(kQuantType == QuantType::ABQuantGrouped ||
|
||||
kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @note RunGemm2LDS runs with two shared memory buffers using the ping-pong buffer mechanism.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param aq_ptr input AQ pointer
|
||||
* @param bq_ptr input BQ pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
|
||||
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
* @tparam DstInMemOp Destination memory operation (default: set).
|
||||
*/
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const AQDataType* aq_ptr,
|
||||
[[maybe_unused]] const AQDataType* aq_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
void* smem_ptr_1,
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_block_window =
|
||||
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = [&]() {
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I3);
|
||||
index_t n = 0;
|
||||
index_t n = 0;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
n = kargs.N;
|
||||
@@ -1320,19 +1923,23 @@ struct QuantGemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I4);
|
||||
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
|
||||
|
||||
// Run Epilogue Pipeline with k_batch dispatch
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return;
|
||||
// throw std::runtime_error("DoubleSmemBuffer Not implemented for AQuantGrouped or
|
||||
// RowColQuant"); static_assert(kQuantType == QuantType::BQuantGrouped,
|
||||
// "DoubleSmemBuffer Not implemented");
|
||||
if(k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1343,16 +1950,19 @@ struct QuantGemmKernel
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
|
||||
|
||||
// Apply splitk offset to input pointers
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
|
||||
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
assert(kargs.k_batch == 1);
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
|
||||
@@ -374,7 +374,7 @@ struct QuantGroupedGemmKernel
|
||||
CK_TILE_DEVICE static void
|
||||
RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const AQDataType* aq_ptr,
|
||||
[[maybe_unused]] const AQDataType* aq_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
@@ -385,25 +385,21 @@ struct QuantGroupedGemmKernel
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
static_assert(kQuantType == QuantType::BQuantGrouped, "kQuantType must be BQuantGrouped");
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_block_window =
|
||||
Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& bq_block_window =
|
||||
Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(Base::I2);
|
||||
|
||||
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
// Run GEMM cooperatively by whole workgroup
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
@@ -411,10 +407,20 @@ struct QuantGroupedGemmKernel
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
// Run Epilogue Pipeline with split_k dispatch
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto c_block_window = Base::template MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window =
|
||||
Base::template MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -449,16 +455,15 @@ struct QuantGroupedGemmKernel
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(Base::I2);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_block_window =
|
||||
Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& aq_block_window =
|
||||
Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& bq_block_window =
|
||||
Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
// Get hot-loop and tail configuration
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
@@ -466,51 +471,77 @@ struct QuantGroupedGemmKernel
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped)
|
||||
// Run GEMM cooperatively by whole workgroup
|
||||
const auto& c_block_tile = [&]() {
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped)
|
||||
{
|
||||
return GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
aq_block_window,
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
return GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant ||
|
||||
kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
}
|
||||
}();
|
||||
|
||||
// Run Epilogue Pipeline with split_k dispatch
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
aq_block_window,
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
auto c_block_window = Base::template MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
auto c_block_window =
|
||||
Base::template MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
|
||||
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
|
||||
Reference in New Issue
Block a user