mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
Merge commit '4208e2898818362735e1ae9980a4cc2fea607ab4' into develop
This commit is contained in:
@@ -250,8 +250,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0]), scale),
|
||||
float_to_e2m1(type_convert<float>(x[1]), scale));
|
||||
return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
|
||||
@@ -259,8 +258,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0]), scale),
|
||||
float_to_e2m1(type_convert<float>(x[1]), scale));
|
||||
return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
|
||||
|
||||
@@ -4,13 +4,18 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_bquant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, index_t UnaryOpSize_ = 8>
|
||||
struct BlockGemmQuantBase
|
||||
struct BlockGemmAQuantBase
|
||||
{
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
@@ -67,7 +67,7 @@ struct BlockGemmQuantBase
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
|
||||
struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
{
|
||||
private:
|
||||
template <typename PipelineProblem_, typename GemmPolicy_>
|
||||
@@ -103,13 +103,13 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
using I1 = number<1>;
|
||||
|
||||
static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
|
||||
"Error! WarpGemm's MWarp is not consisten with BlockGemmShape!");
|
||||
"Error! WarpGemm's MWarp is not consistent with BlockGemmShape!");
|
||||
static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
|
||||
"Error! WarpGemm's NWarp is not consisten with BlockGemmShape!");
|
||||
"Error! WarpGemm's NWarp is not consistent with BlockGemmShape!");
|
||||
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
|
||||
"Error! WarpGemm's M is not consisten with BlockGemmShape!");
|
||||
"Error! WarpGemm's M is not consistent with BlockGemmShape!");
|
||||
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
|
||||
"Error! WarpGemm's N is not consisten with BlockGemmShape!");
|
||||
"Error! WarpGemm's N is not consistent with BlockGemmShape!");
|
||||
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
@@ -170,7 +170,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
|
||||
using Base = BlockGemmQuantBase<Problem_>;
|
||||
using Base = BlockGemmAQuantBase<Problem_>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
@@ -323,7 +323,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
[[maybe_unused]] BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"The CDataType as defined in traits should be the same as corresponding "
|
||||
"C block tensor data type!");
|
||||
constexpr auto warp_size = get_warp_size();
|
||||
|
||||
|
||||
@@ -0,0 +1,439 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, index_t UnaryOpSize_ = 8>
|
||||
struct BlockGemmBQuantBase
|
||||
{
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
|
||||
static constexpr index_t UnaryOpSize = UnaryOpSize_;
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale)
|
||||
{
|
||||
float scale_reg_f = 0.f;
|
||||
if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
scale_reg_f =
|
||||
ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
|
||||
{
|
||||
scale_reg_f =
|
||||
ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, float>)
|
||||
{
|
||||
scale_reg_f = ck_tile::bit_cast<float>(scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "BQDataType must be float, fp8_t or bf8_t.");
|
||||
}
|
||||
return scale_reg_f;
|
||||
}
|
||||
|
||||
// can be inherited from A
|
||||
template <typename WarpWindow, typename WarpTile>
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
|
||||
const WarpWindow& warp_window)
|
||||
{
|
||||
const element_wise::PassThroughPack8 elementwise_op{};
|
||||
|
||||
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
|
||||
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
|
||||
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// A is block window on shared memory
|
||||
// BQ (scale tensor) is block distributed tensor.
|
||||
// Consecutive kQuantGroupSize elements of B are quantized with a separate scale.
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
|
||||
struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
{
|
||||
private:
|
||||
template <typename PipelineProblem_, typename GemmPolicy_>
|
||||
struct GemmTraits_
|
||||
{
|
||||
using Problem = remove_cvref_t<PipelineProblem_>;
|
||||
using Policy = remove_cvref_t<GemmPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
// Threadblock GEMM tile size
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t BQPerBlock = KPerBlock / kQuantGroupSize;
|
||||
|
||||
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
// number of warps along M and N for threadblock's GEMM problem size
|
||||
static constexpr index_t MWarp = config.template at<1>();
|
||||
static constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
|
||||
static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
|
||||
"Error! WarpGemm's MWarp is not consistent with BlockGemmShape!");
|
||||
static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
|
||||
"Error! WarpGemm's NWarp is not consistent with BlockGemmShape!");
|
||||
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
|
||||
"Error! WarpGemm's M is not consistent with BlockGemmShape!");
|
||||
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
|
||||
"Error! WarpGemm's N is not consistent with BlockGemmShape!");
|
||||
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
(WarpGemm::kK + kQuantGroupSize - 1) / kQuantGroupSize;
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
|
||||
|
||||
static_assert(kQuantGroupSize % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of kQuantGroupSize");
|
||||
static_assert(QScalesPerWarpGemmRow == 1,
|
||||
"Error! kQuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
|
||||
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
|
||||
|
||||
static_assert(KPerBlock / kQuantGroupSize > 0,
|
||||
"Error! Each row of blockgemm should have a separate scale");
|
||||
|
||||
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
|
||||
"Error! Warps should cover all Block tile!");
|
||||
static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock,
|
||||
"Error! Warps should cover all Block tile!");
|
||||
|
||||
// Currently tested combinations (A, B, BQ)
|
||||
// 1. fp8, fp8, fp32 -> f32
|
||||
// 2. bf8, bf8, fp32 -> f32
|
||||
// 3. i4, fp8, (fp8/fp32) -> f32
|
||||
// 4. i4, bf8, (fp8/fp32) -> f32
|
||||
static_assert((std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>) &&
|
||||
(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t>) &&
|
||||
(std::is_same_v<BQDataType, float> ||
|
||||
std::is_same_v<BQDataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<BQDataType, ck_tile::bf8_t>) &&
|
||||
(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>) &&
|
||||
std::is_same_v<CDataType, fp32_t>);
|
||||
|
||||
static constexpr index_t InterWaveSchedulingMacClusters = 1;
|
||||
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
|
||||
};
|
||||
|
||||
public:
|
||||
using Traits = GemmTraits_<Problem_, Policy_>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Traits::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Traits::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename Traits::BQDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
|
||||
using Base = BlockGemmBQuantBase<Problem_>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
|
||||
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
|
||||
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
|
||||
|
||||
static constexpr index_t MWarp = Traits::MWarp;
|
||||
static constexpr index_t NWarp = Traits::NWarp;
|
||||
|
||||
static constexpr auto Scheduler = Traits::Scheduler;
|
||||
static constexpr uint8_t kA_cvt_scale = std::is_same_v<ADataType, pk_int4_t> ? 16 : 1;
|
||||
static constexpr uint8_t kB_cvt_scale = std::is_same_v<BDataType, pk_int4_t> ? 16 : 1;
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
|
||||
|
||||
static constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
|
||||
|
||||
constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
|
||||
|
||||
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
|
||||
sequence<KIterInterwave>,
|
||||
sequence<KIterPerWarp>>;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, KIterSeq>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
|
||||
constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
|
||||
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
|
||||
sequence<KIterInterwave>,
|
||||
sequence<KIterPerWarp>>;
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, KIterSeq>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
|
||||
private:
|
||||
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
|
||||
struct BlockGemmImpl
|
||||
{
|
||||
};
|
||||
|
||||
template <typename GemmTraits>
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
|
||||
{
|
||||
static constexpr auto ALdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Base::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_block_window);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Base::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_block_window);
|
||||
}
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor,
|
||||
typename BQBlockTensor,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
BQBlockTensor& bq_block_tensor,
|
||||
[[maybe_unused]] ASmemBlockWindow& a_block_window,
|
||||
[[maybe_unused]] BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as corresponding "
|
||||
"C block tensor data type!");
|
||||
|
||||
// hot loop:
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
|
||||
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
|
||||
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
|
||||
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() =
|
||||
a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() =
|
||||
b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
if constexpr(kIterInQScale == 0)
|
||||
{
|
||||
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
});
|
||||
|
||||
// Need to multiply bquant with accumulated C
|
||||
//
|
||||
// The accumulated C tile has the standard distribution. For example
|
||||
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
|
||||
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
|
||||
// [26,0], [27,0].
|
||||
//
|
||||
// These elements are in different rows, need to get the scale value
|
||||
// for the corresponding row.
|
||||
// Based on bquant's tile distribution, it can be inferred which
|
||||
// lane holds the relevant scale. For example, the scales corresponding
|
||||
// to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9,
|
||||
// 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 respectively.
|
||||
//
|
||||
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
|
||||
|
||||
constexpr index_t reg_offset = nIter * Traits::BQPerBlock + kQScale;
|
||||
|
||||
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
|
||||
static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f *
|
||||
kA_cvt_scale * kB_cvt_scale);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor,
|
||||
typename BQBlockTensor,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
BQBlockTensor& bq_block_tensor,
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
block_gemm_impl_(c_block_tensor, bq_block_tensor, a_block_window, b_block_window);
|
||||
}
|
||||
|
||||
private:
|
||||
BlockGemmImpl<Scheduler, Traits> block_gemm_impl_{};
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,679 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BQuantGemmProblem
|
||||
{
|
||||
CK_TILE_HOST BQuantGemmProblem() = default;
|
||||
CK_TILE_HOST BQuantGemmProblem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t QK_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_,
|
||||
index_t stride_BQ_)
|
||||
: M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
QK(QK_),
|
||||
stride_A(stride_A_),
|
||||
stride_B(stride_B_),
|
||||
stride_C(stride_C_),
|
||||
stride_BQ(stride_BQ_)
|
||||
{
|
||||
}
|
||||
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t QK;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
index_t stride_BQ;
|
||||
};
|
||||
|
||||
struct BQuantGemmHostArgs : public BQuantGemmProblem
|
||||
{
|
||||
CK_TILE_HOST BQuantGemmHostArgs() = default;
|
||||
CK_TILE_HOST BQuantGemmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* c_ptr_,
|
||||
const void* bq_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t QK_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_,
|
||||
index_t stride_BQ_)
|
||||
: BQuantGemmProblem(M_, N_, K_, QK_, stride_A_, stride_B_, stride_C_, stride_BQ_),
|
||||
a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
bq_ptr(bq_ptr_),
|
||||
c_ptr(c_ptr_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
const void* bq_ptr;
|
||||
void* c_ptr;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
struct BQuantGemmKernelArgs
|
||||
{
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
const void* bq_ptr;
|
||||
void* c_ptr;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t QK;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
index_t stride_BQ;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct BQuantGemmKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using BQLayout = remove_cvref_t<typename GemmPipeline::BQLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename GemmPipeline::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr BQuantGemmKernelArgs
|
||||
MakeKernelArgs(const BQuantGemmHostArgs& hostArgs)
|
||||
{
|
||||
return BQuantGemmKernelArgs{hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.bq_ptr,
|
||||
hostArgs.c_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.QK,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_C,
|
||||
hostArgs.stride_BQ,
|
||||
hostArgs.k_batch};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
__device__ SplitKBatchOffset(const BQuantGemmKernelArgs& kargs,
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
|
||||
const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A);
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B);
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
|
||||
}
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
||||
{
|
||||
splitted_k = __builtin_amdgcn_readfirstlane(KRead);
|
||||
}
|
||||
else
|
||||
{
|
||||
splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
|
||||
}
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const BQuantGemmKernelArgs& kargs)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
if(kargs.QK % GemmPipeline::GetVectorSizeBQ() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
|
||||
GemmPipeline::kPadK == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
|
||||
"without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support M that is not a multiple of MPerBlock without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support N that is not a multiple of NPerBlock without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
|
||||
GemmPipeline::kPadK == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
|
||||
"without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support N that is not a multiple of NPerBlock without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support M that is not a multiple of MPerBlock without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
const BQuantGemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& bq_tensor_view = [&]() {
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(kargs.N, kargs.QK),
|
||||
make_tuple(kargs.stride_BQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}();
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = splitk_batch_offset.splitted_k / K1;
|
||||
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
|
||||
make_tuple(kargs.N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
const auto b_n_k_desc = transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(kargs.N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.N),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = splitk_batch_offset.splitted_k / K1;
|
||||
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
|
||||
make_tuple(kargs.N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
const auto b_n_k_desc = transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(kargs.N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& c_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_C),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_tensor_view, bq_tensor_view, b_tensor_view, c_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& bq_pad_view = [&]() {
|
||||
const auto& bq_tensor_view = views.at(I1);
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return pad_tensor_view(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
|
||||
// TODO: Add support for padding.
|
||||
sequence<false, false>{});
|
||||
}();
|
||||
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I2);
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// TODO vector write in for C in ColMajor
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I3);
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, bq_pad_view, b_pad_view, c_pad_view);
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& bq_pad_view = views.at(I1);
|
||||
const auto& b_pad_view = views.at(I2);
|
||||
const auto& c_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& bq_block_window = [&]() {
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
|
||||
{i_n, 0});
|
||||
}();
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
}
|
||||
}();
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return make_tuple(a_block_window, bq_block_window, b_block_window, c_block_window);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param bq_ptr input BQ pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
* @tparam DstInMemOp Destination memory operation (default: set).
|
||||
*/
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const BQuantGemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
|
||||
a_ptr, b_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(BQuantGemmKernelArgs kargs) const
|
||||
{
|
||||
const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
|
||||
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
assert(kargs.k_batch == 1);
|
||||
RunGemm(a_ptr, b_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -28,7 +28,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
|
||||
static_assert(std::is_same_v<AQLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
return GetAQGlobalVectorLoadSize<Problem, AQDataType, MPerBlock, KPerBlockAQ>();
|
||||
return GetABQGlobalVectorLoadSize<Problem, AQDataType, MPerBlock, KPerBlockAQ>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy>
|
||||
struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
|
||||
{
|
||||
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
using ADataType = typename Base::ADataType;
|
||||
using ALayout = typename Base::ALayout;
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize;
|
||||
|
||||
static_assert(KPerBlock % QuantGroupSize == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize");
|
||||
|
||||
// Create DRAM tile window for BQ
|
||||
template <typename BQDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
GetBQDramLoadWindow(const BQDramBlockWindowTmp& bq_dram_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
|
||||
using YPerTile = number<NPerBlock>;
|
||||
using XPerTile = number<KPerBlockBQ>;
|
||||
|
||||
auto bq_copy_dram_window =
|
||||
make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile(), XPerTile()),
|
||||
bq_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBQDramTileDistribution<Problem>());
|
||||
return bq_copy_dram_window;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,93 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "gemm_group_quant_utils.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy
|
||||
{
|
||||
using Base = UniversalGemmPipelineAgBgCrPolicy;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::I2;
|
||||
|
||||
using Base::ATileAccessPattern;
|
||||
using Base::BTileAccessPattern;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ()
|
||||
{
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
|
||||
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlock, KPerBlockBQ>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution()
|
||||
{
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BlockGemmShape = typename Problem::BlockGemmShape;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using TileEncodingPattern = TileDistributionEncodingPatternBQ<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
NPerBlock,
|
||||
KPerBlockBQ,
|
||||
VecLoadSize>;
|
||||
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of kQuantGroupSize!");
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<typename Problem::ComputeDataType, bf8_t>);
|
||||
static_assert(std::is_same_v<typename Problem::CDataType, float>);
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return BQuantBlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,475 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseBQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
if(has_hot_loop)
|
||||
{
|
||||
if(tail_number == ck_tile::TailNumber::Full)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_number == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == ck_tile::TailNumber::Even)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported tail number for this operation !!!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number == ck_tile::TailNumber::Full)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_number == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == ck_tile::TailNumber::Even)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported tail number for this operation !!!");
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem, typename Policy = GemmBQuantPipelineAgBgCrDefaultPolicy>
|
||||
struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmBQuantPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t BQPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BQDataType>>::PackedSize;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeBQ()
|
||||
{
|
||||
return Policy::template GetVectorSizeBQ<Problem>();
|
||||
}
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
using Base::PrefetchStages;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
return concat('_', "bquant_pipeline_AgBgCrCompV3",
|
||||
concat('x', MPerBlock, NPerBlock, KPerBlock),
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
|
||||
concat('x', kPadM, kPadN, kPadK), "QuantGroupSize", QuantGroupSize);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST static std::string Print()
|
||||
{
|
||||
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
|
||||
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
|
||||
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = 64;
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
|
||||
constexpr index_t A_LDS_Read_Width = GetSmemPackA();
|
||||
constexpr index_t B_LDS_Read_Width = GetSmemPackB();
|
||||
|
||||
constexpr index_t A_LDS_Write_Width = GetSmemPackA();
|
||||
constexpr index_t B_LDS_Write_Width = GetSmemPackB();
|
||||
|
||||
constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
|
||||
constexpr index_t BQ_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlockBQ / (BlockSize * GetVectorSizeBQ());
|
||||
|
||||
constexpr index_t A_LDS_Write_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
|
||||
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
auto str = std::stringstream{};
|
||||
|
||||
str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", "
|
||||
<< "BQ vector size: " << GetVectorSizeBQ() << "\n"
|
||||
<< "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
|
||||
<< "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
|
||||
<< ", " << "BQ buffer load inst: " << BQ_Buffer_Load_Inst_Num << "\n"
|
||||
<< "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename BQDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BQDataType,
|
||||
remove_cvref_t<typename BQDramBlockWindowTmp::DataType>>,
|
||||
"A/B/BQ Dram block window should have the same data type as appropriate "
|
||||
"([A|B|BQ]DataType) defined in Problem definition!");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_bq_col_major =
|
||||
std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
|
||||
static_assert(NPerBlock == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
|
||||
"Bq block window has incorrect lengths for defined BqLayout!");
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
using BQDramTileWindowStep = typename BQDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
|
||||
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
|
||||
auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
auto bq_copy_dram_window = Base::GetBQDramLoadWindow(bq_dram_block_window_tmp);
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
using BQBlockTileDistr = decltype(bq_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
using BQBlockTile =
|
||||
decltype(make_static_distributed_tensor<BQDataType>(BQBlockTileDistr{}));
|
||||
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
ABlockTile a_block_tile;
|
||||
BBlockTile b_block_tile;
|
||||
BQBlockTile bq_block_tile[2];
|
||||
int currIdx = 0;
|
||||
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BQDramTileWindowStep bq_dram_tile_window_step =
|
||||
is_bq_col_major ? make_array(0, KPerBlockBQ) : make_array(KPerBlockBQ, 0);
|
||||
|
||||
// DRAM prefetch (global read 0)
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step);
|
||||
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2],
|
||||
bq_copy_dram_window,
|
||||
bq_dram_tile_window_step);
|
||||
|
||||
block_gemm(
|
||||
c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
currIdx = (currIdx + 1) % 2;
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
// tail
|
||||
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
|
||||
{
|
||||
block_gemm(
|
||||
c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2],
|
||||
bq_copy_dram_window,
|
||||
bq_dram_tile_window_step);
|
||||
block_gemm(
|
||||
c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
currIdx = (currIdx + 1) % 2;
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
}
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(
|
||||
c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename BQDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
bq_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -8,7 +8,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename DataType, index_t YPerTile, index_t XPerTile>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAQGlobalVectorLoadSize()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetABQGlobalVectorLoadSize()
|
||||
{
|
||||
using I1 = number<1>;
|
||||
constexpr index_t NWarps = Problem::BlockGemmShape::BlockWarps::at(I1{});
|
||||
@@ -164,4 +164,56 @@ struct TileDistributionEncodingPatternAQTransposedC : public TileDistributionEnc
|
||||
}
|
||||
};
|
||||
|
||||
// TODO:: might need to update
|
||||
template <typename BlockGemmShape,
|
||||
typename WarpGemm,
|
||||
index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize>
|
||||
struct TileDistributionEncodingPatternBQ : public TileDistributionEncodingPattern
|
||||
{
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
|
||||
static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
|
||||
static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
|
||||
static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
|
||||
|
||||
static constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarps * WarpGemm::kN);
|
||||
|
||||
static_assert(num_warps == MWarps * NWarps * KWarps);
|
||||
|
||||
// KWarps > 1 isn't supported
|
||||
static_assert(KWarps == 1);
|
||||
|
||||
// # of elements per thread
|
||||
static constexpr index_t X = XPerTile;
|
||||
static constexpr index_t XR = 2;
|
||||
|
||||
// Number of iters per warp
|
||||
// MIters are indexed using (Y0, Y1)
|
||||
static constexpr index_t Y0 = NIterPerWarp;
|
||||
|
||||
// # of warps in Y dim
|
||||
static constexpr index_t Y1 = NWarps;
|
||||
|
||||
static constexpr index_t Y2 = WarpGemm::kN;
|
||||
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, XR>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -121,4 +121,107 @@ using GemmAQuantPipelineProblem = GemmAQuantPipelineProblemBase<ADataType_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename BQDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
uint32_t QuantGroupSize_,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full>
|
||||
struct GemmBQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_>
|
||||
{
|
||||
using Base = GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_>;
|
||||
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
using typename Base::ADataType;
|
||||
using typename Base::BDataType;
|
||||
using typename Base::CDataType;
|
||||
using typename Base::ComputeDataType;
|
||||
using BQDataType = remove_cvref_t<BQDataType_>;
|
||||
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
|
||||
using typename Base::ALayout;
|
||||
using typename Base::BLayout;
|
||||
using typename Base::CLayout;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
|
||||
using Base::kBlockSize;
|
||||
|
||||
using Base::kPadK;
|
||||
using Base::kPadM;
|
||||
using Base::kPadN;
|
||||
|
||||
using Base::DoubleSmemBuffer;
|
||||
using Base::VectorLoadSize;
|
||||
|
||||
using BQLayout = remove_cvref_t<typename Traits::BQLayout>;
|
||||
|
||||
static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
|
||||
static_assert(BlockGemmShape::kK % kQuantGroupSize == 0);
|
||||
static_assert(Scheduler == GemmPipelineScheduler::Intrawave);
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm_bquant_problem",
|
||||
concat('x', VectorLoadSize, kBlockSize),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Scheduler,
|
||||
"QuantGroupSize",
|
||||
kQuantGroupSize);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBQ()
|
||||
{
|
||||
return VectorLoadSize / sizeof(BQDataType);
|
||||
}
|
||||
|
||||
static constexpr index_t VectorSizeBQ = []() { return kPadK ? 1 : GetAlignmentBQ(); }();
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename BQDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
uint32_t QuantGroupSize_,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full>
|
||||
using GemmBQuantPipelineProblem = GemmBQuantPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
BQDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
QuantGroupSize_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -28,6 +28,35 @@ struct TileGemmAQuantTraits
|
||||
using CLayout = CLayout_;
|
||||
using AQLayout = AQLayout_;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
|
||||
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
|
||||
};
|
||||
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool PreshuffleQuant_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
typename BQLayout_ = BLayout_>
|
||||
struct TileGemmBQuantTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
|
||||
static constexpr int _VectorSize = 16;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using CLayout = CLayout_;
|
||||
using BQLayout = BQLayout_;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
|
||||
Reference in New Issue
Block a user