Sync with develop

This commit is contained in:
KenSCLin
2025-12-10 09:59:00 +00:00
parent ddfea2a784
commit 4f207de1b8
8 changed files with 257 additions and 202 deletions

View File

@@ -18,11 +18,11 @@ void abquant_quantgrouped_instance_factory(
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type_layout<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"abquant",
@@ -33,49 +33,40 @@ void abquant_quantgrouped_instance_factory(
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type_layout<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8",
"abquant",
"non-preshuffleb",
"non-preshufflequant",
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8",
"abquant",
"non-preshuffleb",
"non-preshufflequant",
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp8i4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8i4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
}

View File

@@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[])
.insert("prec",
"fp8",
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
"or bf8i4; for ABQuant: fp8, bf8, i4fp8, or i4bf8")
"or bf8i4; for ABQuant: fp8, bf8")
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")

View File

@@ -162,11 +162,11 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;
constexpr bool TiledPermuteN =
(QuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN;
(BQuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN;
if(s.log_level_ > 0)
{
printf(
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, QuantGroupSize::kN);
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN);
}
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<typename TypeConfig::ADataType,
@@ -440,31 +440,30 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
}
}
ck_tile::index_t AQK, BQK, BQN;
ck_tile::index_t AQK, BQK, BQN = 0;
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
AQK = K / AQuantGroupSize::kK; // Group quantization: AQK = K / GroupSize
BQK = 0; // No B quantization
BQN = 0;
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
AQK = 0; // No A quantization
BQK = K / BQuantGroupSize::kK; // Group quantization: BQK = K / GroupSize
BQN = N / BQuantGroupSize::kN; // Group quantization: BQN = N / GroupSize
BQN = ck_tile::integer_divide_ceil(N, BQuantGroupSize::kN);
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
{
AQK = K / AQuantGroupSize::kK; // Group quantization: AQK = K / GroupSize
BQK = K / BQuantGroupSize::kK; // Group quantization: BQK = K / GroupSize
BQN = N / BQuantGroupSize::kN; // Group quantization: BQN = N / GroupSize
BQN = ck_tile::integer_divide_ceil(N, BQuantGroupSize::kN);
}
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
BQN = 0;
BQN = 1;
}
else
{
@@ -540,16 +539,12 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
// Create BQ tensor with appropriate shape
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::ABQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
}
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
@@ -715,7 +710,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
if constexpr(GemmConfig::PreshuffleB)
{
if constexpr(GemmConfig::TiledMMAPermuteN && QuantGroupSize::kN == 1)
if constexpr(GemmConfig::TiledMMAPermuteN && BQuantGroupSize::kN == 1)
{
printf("PreshuffleB with TiledMMAPermuteN\n");
b_k_n_dev = ck_tile::shuffle_b_permuteN<GemmConfig>(b_k_n);
@@ -742,10 +737,10 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
QuantMode == ck_tile::QuantType::TensorQuant)
{
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN &&
QuantGroupSize::kN == 1)
BQuantGroupSize::kN == 1)
{
ck_tile::HostTensor<BQDataType> bq_permuted_host =
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr, QuantGroupSize::kN);
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr, BQuantGroupSize::kN);
if constexpr(GemmConfig::PreshuffleQuant)
{
@@ -895,66 +890,6 @@ template <typename GemmConfig,
typename AQuantGroupSize,
typename BQuantGroupSize,
ck_tile::QuantType QuantMode>
int run_gemm_example_prec_type_layout(const ck_tile::ArgParser& arg_parser)
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if(QuantMode == ck_tile::QuantType::ABQuantGrouped && GemmConfig::PreshuffleB)
{
throw std::runtime_error("Preshuffling weight matrix is not supported for ABQuant");
}
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
{
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
arg_parser, Row{}, Row{}, Col{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
arg_parser, Col{}, Row{}, Row{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
arg_parser, Row{}, Row{}, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
else
{
throw std::runtime_error("Unsupported data type for A.");
}
return 0;
}
template <typename GemmConfig,
typename TypeConfig,
typename QuantGroupSize,
ck_tile::QuantType QuantMode>
int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
@@ -979,19 +914,22 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
{
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
QuantGroupSize,
QuantGroupSize,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
arg_parser, Row{}, Row{}, Col{}, Col{}, Row{});
}
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped && !GemmConfig::PreshuffleQuant)
if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
!GemmConfig::PreshuffleQuant)
{
if(a_layout == "R" && b_layout == "R")
{
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
QuantGroupSize,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
arg_parser, Row{}, Row{}, Row{}, Col{}, Row{});
}
@@ -999,24 +937,24 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
{
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
QuantGroupSize,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
arg_parser, Col{}, Row{}, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
}
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped && !GemmConfig::PreshuffleQuant)
{
if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
QuantGroupSize,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
arg_parser, Col{}, Col{}, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
@@ -1029,3 +967,16 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
return 0;
}
template <typename GemmConfig,
typename TypeConfig,
typename QuantGroupSize,
ck_tile::QuantType QuantMode>
int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
{
return run_gemm_example_prec_type<GemmConfig,
TypeConfig,
QuantGroupSize,
QuantGroupSize,
QuantMode>(arg_parser);
}

View File

@@ -33,6 +33,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
@@ -79,13 +80,13 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
integer_divide_ceil(BQuantGroupSize::kK, WarpGemm::kK);
integer_divide_ceil(WarpGemm::kK, BQuantGroupSize::kK);
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static_assert(BQuantGroupSize::kK % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
static_assert(QScalesPerWarpGemmRow > 1,
static_assert(QScalesPerWarpGemmRow == 1,
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
@@ -132,6 +133,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using Base = BlockGemmQuantBase;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
@@ -152,6 +156,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant;
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
static constexpr auto a_warp_y_lengths =
@@ -235,7 +241,6 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
{
public:
static constexpr auto ALdsTileDistr =
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
static constexpr auto BLdsTileDistr =
@@ -247,12 +252,20 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_>(a_warp_tile_, a_block_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_>(b_warp_tile_, b_block_window);
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
// If B datatype were pkint4 it would be converted prior to storing in LDS
load_int4_tile<OverrideBDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
}
// C += A * B
@@ -267,7 +280,6 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
[[maybe_unused]] ASmemBlockWindow& a_block_window,
[[maybe_unused]] BSmemBlockWindow& b_block_window)
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as corresponding "
"C block tensor data type!");
@@ -303,47 +315,78 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
});
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
// a_scale
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
aq_block_tensor);
// Multiply bquant with accumulated C
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN))
return (nIter * NWarp * WarpGemm::kN) /
GemmTraits::BQuantGroupSize::kN * Traits::KQPerBlock +
kQScale;
if constexpr(PreshuffleQuant)
{
constexpr index_t reg_offset = nIter;
auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;
if constexpr(std::is_same_v<BQDataType, float>)
{
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
return nIter * Traits::KQPerBlock + kQScale;
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
}();
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_reg_f =
Base::cvt_scale_to_fp32<typename Traits::BQDataType>(scale_reg);
// cross lane ops to get the value of scale_reg.
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_reg_f =
Base::cvt_scale_to_fp32<typename Traits::BQDataType>(scale_reg);
float b_scale_reg_f =
Base::cvt_scale_to_fp32<typename Traits::BQDataType>(
gathered_scale_reg);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f *
b_scale_reg_f);
});
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f *
b_scale_reg_f);
});
}
else
{
// Multiply bquant with accumulated C
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >=
(NWarp * WarpGemm::kN))
return (nIter * NWarp * WarpGemm::kN) /
GemmTraits::BQuantGroupSize::kN *
Traits::KQPerBlock +
kQScale;
else
{
return nIter * Traits::KQPerBlock + kQScale;
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_reg_f =
Base::cvt_scale_to_fp32<typename Traits::BQDataType>(scale_reg);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f *
b_scale_reg_f);
});
}
});
});
});
@@ -357,11 +400,16 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
MakeCBlockTile();
}
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
}
// C += A * B

View File

@@ -125,6 +125,9 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using Base = BlockGemmQuantBase;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;

View File

@@ -34,6 +34,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
@@ -98,6 +101,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
using Base::PrefetchStages;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
@@ -110,7 +116,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
concat('x', kPadM, kPadN, kPadK), BQuantGroupSize::GetName());
concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), BQuantGroupSize::GetName());
// clang-format on
}
@@ -142,7 +148,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
constexpr index_t AQ_Buffer_Load_Inst_Num =
MPerBlock * KPerBlockAQ / (BlockSize * GetVectorSizeAQ());
constexpr index_t BQ_Buffer_Load_Inst_Num =
NPerBlock * KPerBlockBQ / (BlockSize * GetVectorSizeBQ());
NPerBlockBQ * KPerBlockBQ / (BlockSize * GetVectorSizeBQ());
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
@@ -187,6 +193,26 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
{
using Base = PipelineImplBase;
template <typename ADramWindow, typename ABlockTile_>
CK_TILE_DEVICE static void LoadAndConvertATile(ABlockTile_& a_block_tile,
const ADramWindow& a_dram_window)
{
using DestDataType = typename ABlockTile_::DataType;
using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
constexpr index_t UnaryOpSize = 8;
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(a_block_tile, a_dram_window);
}
template <typename BDramWindow, typename BBlockTile_>
CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile,
const BDramWindow& b_dram_window)
{
using DestDataType = typename BBlockTile_::DataType;
using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType;
constexpr index_t UnaryOpSize = 8;
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(b_block_tile, b_dram_window);
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
@@ -221,12 +247,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_aq_col_major =
std::is_same_v<AQLayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_bq_col_major =
std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)");
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
constexpr bool is_bq_row_major =
std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
@@ -240,13 +263,23 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
static_assert(
PreshuffleQuant ||
(is_bq_row_major
? (KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])),
"Bq block window has incorrect lengths for defined BqLayout!");
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex;
using BQDramTileWindowStep = typename BQDramBlockWindowTmp::BottomTensorIndex;
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
// Note: BDataType PkInt4 gets converted during loading, before going to LDS
auto&& [a_lds_block, b_lds_block] =
Base::template GetABLdsTensorViews<ADataType, OverrideBDataType>(p_smem);
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
@@ -290,20 +323,28 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// only row_major for AQ
const AQDramTileWindowStep aq_dram_tile_window_step =
PreshuffleQuant ? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
BlockGemm::WarpGemm::kM,
0)
: make_array(0, KPerBlockAQ);
PreshuffleQuant
? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
BlockGemm::WarpGemm::kM,
0)
: (is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ));
const BQDramTileWindowStep bq_dram_tile_window_step =
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
0)
: is_bq_col_major ? make_array(0, KPerBlockBQ)
: make_array(KPerBlockBQ, 0);
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
: make_array(0, KPerBlockBQ);
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
// Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
// B tile gets converted to A datatype during loading
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(
aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step);
Base::GlobalPrefetch(
@@ -311,7 +352,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -323,7 +364,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -335,12 +376,18 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
// Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
__builtin_amdgcn_sched_barrier(0);
@@ -353,7 +400,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
{
block_sync_lds();
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -364,9 +411,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
// Note: BDataType PkInt4 gets converted during loading earlier
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
@@ -376,8 +424,16 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// Base::GlobalPrefetch(a_block_tile, a_copy_dram_window,
// a_dram_tile_window_step);
// Base::GlobalPrefetch(b_block_tile, b_copy_dram_window,
// b_dram_tile_window_step);
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2],
aq_copy_dram_window,
aq_dram_tile_window_step);
@@ -395,7 +451,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
__builtin_amdgcn_sched_barrier(0);
i += 1;
@@ -440,7 +497,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
// Note: BDataType gets converted during loading from PkInt4
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
@@ -450,7 +508,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
block_gemm(c_block_tile,
aq_block_tile[currIdx],
bq_block_tile[currIdx],
@@ -499,6 +558,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
///
/// @param a_dram_block_window_tmp Block window for A tensor in DRAM
/// @param b_dram_block_window_tmp Block window for B tensor in DRAM
/// @param aq_dram_block_window_tmp Block window for AQ (quantization scale) tensor in DRAM
/// @param bq_dram_block_window_tmp Block window for BQ (quantization scale) tensor in DRAM
/// @param num_loop Number of main loop iterations (calculated on device)
/// @param has_hot_loop Whether the pipeline has a hot loop (calculated on device)
@@ -528,7 +588,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
// Note: BDataType PkInt4 gets converted during loading
[](const OverrideBDataType& b) { return b; },
aq_dram_block_window_tmp,
bq_dram_block_window_tmp,
m,

View File

@@ -80,9 +80,10 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
tile_distribution_encoding_pattern_bq<BlockGemmShape,
WarpGemm,
BlockSize,
NPerBlockBQ,
KPerBlockBQ,
Problem::BQuantGroupSize::kN>;
KPerBlockBQ, // Logical K dimension
NPerBlockBQ, // Logical N dimension
Problem::BQuantGroupSize::kN,
BQLayout>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}

View File

@@ -52,7 +52,7 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
using AQuantGroupSize = AQuantGroupSize_;
using BQuantGroupSize = BQuantGroupSize_;
// For backward compatibility
using QuantGroupSize = AQuantGroupSize_;
using QuantGroupSize = BQuantGroupSize_;
using typename Base::ALayout;
using typename Base::BLayout;